client/
websocket.rs

1// src/client/websocket.rs
2//! WebSocket client for Bilibili live danmaku messages (refactored from bili_live_dm)
3
4use native_tls::TlsStream;
5use serde_json::Value;
6use std::net::TcpStream;
7use tungstenite::{client, Message, WebSocket};
8
9use url::Url;
10
11use futures_channel::mpsc::Sender;
12use http::Response;
13use std::collections::HashMap;
14
15use crate::auth::*;
16use crate::models::{AuthMessage, BiliMessage, DanmuServer, MsgHead};
17
18pub struct BiliLiveClient {
19    ws: WebSocket<TlsStream<TcpStream>>,
20    auth_msg: String,
21    ss: Sender<BiliMessage>,
22}
23
24impl BiliLiveClient {
25    pub fn new(cookies: &str, room_id: &str, r: Sender<BiliMessage>) -> Self {
26        let (v, auth) = init_server(cookies, room_id);
27        let (ws, _res) = connect(v["host_list"].clone());
28        BiliLiveClient {
29            ws,
30            auth_msg: serde_json::to_string(&auth).unwrap(),
31            ss: r,
32        }
33    }
34
35    /// Create a new client with automatic browser cookie detection
36    /// If cookies is None or empty, it will try to find cookies from browser
37    pub fn new_auto(
38        cookies: Option<&str>,
39        room_id: &str,
40        r: Sender<BiliMessage>,
41    ) -> Result<Self, String> {
42        let (v, auth) = init_server_auto(cookies, room_id)?;
43        let (ws, _res) = connect(v["host_list"].clone());
44        Ok(BiliLiveClient {
45            ws,
46            auth_msg: serde_json::to_string(&auth).unwrap(),
47            ss: r,
48        })
49    }
50
51    pub fn send_auth(&mut self) {
52        let _ = self.ws.send(Message::Binary(make_packet(
53            self.auth_msg.as_str(),
54            Operation::AUTH,
55        )));
56    }
57
58    pub fn send_heart_beat(&mut self) {
59        let _ = self
60            .ws
61            .send(Message::Binary(make_packet("{}", Operation::HEARTBEAT)));
62    }
63
64    pub fn parse_ws_message(&mut self, resv: Vec<u8>) {
65        let mut offset = 0;
66        let header = &resv[0..16];
67        let mut head_1 = get_msg_header(header);
68        if head_1.operation == 5 || head_1.operation == 8 {
69            loop {
70                let body: &[u8] = &resv[offset + 16..offset + (head_1.pack_len as usize)];
71                self.parse_business_message(head_1, body);
72                offset += head_1.pack_len as usize;
73                if offset >= resv.len() {
74                    break;
75                }
76                let temp_head = &resv[offset..(offset + 16)];
77                head_1 = get_msg_header(temp_head);
78            }
79        } else if head_1.operation == 3 {
80            let mut body: [u8; 4] = [0, 0, 0, 0];
81            body[0] = resv[16];
82            body[1] = resv[17];
83            body[2] = resv[18];
84            body[3] = resv[19];
85            let popularity = i32::from_be_bytes(body);
86            log::info!("popularity:{}", popularity);
87        } else {
88            log::error!(
89                "unknown message operation={:?}, header={:?}}}",
90                head_1.operation,
91                head_1
92            )
93        }
94    }
95
96    pub fn parse_business_message(&mut self, h: MsgHead, b: &[u8]) {
97        if h.operation == 5 {
98            if h.ver == 3 {
99                let res: Vec<u8> = decompress(b).unwrap();
100                self.parse_ws_message(res);
101            } else if h.ver == 0 {
102                let s = String::from_utf8(b.to_vec()).unwrap();
103                let res_json: Value = serde_json::from_str(s.as_str()).unwrap();
104                if let Some(msg) = handle(res_json) {
105                    if let BiliMessage::Unsupported = msg {
106                        return;
107                    }
108                    let _ = self.ss.try_send(msg);
109                }
110            } else {
111                log::error!("Unknown compression format");
112            }
113        } else if h.operation == 8 {
114            self.send_heart_beat();
115        } else {
116            log::error!("Unknown message format {}", h.operation);
117        }
118    }
119
120    pub fn recive(&mut self) {
121        if self.ws.can_read() {
122            let msg = self.ws.read();
123            match msg {
124                Ok(m) => {
125                    let res = m.into_data();
126                    if res.len() >= 16 {
127                        self.parse_ws_message(res);
128                    }
129                }
130                Err(_) => {
131                    panic!("read msg error");
132                }
133            }
134        }
135    }
136}
137
138pub fn gen_damu_list(list: &Value) -> Vec<DanmuServer> {
139    let server_list = list.as_array().unwrap();
140    let mut res: Vec<DanmuServer> = Vec::new();
141    if server_list.len() == 0 {
142        let d = DanmuServer::default();
143        res.push(d);
144    }
145    for s in server_list {
146        res.push(DanmuServer {
147            host: s["host"].as_str().unwrap().to_string(),
148            port: s["port"].as_u64().unwrap() as i32,
149            wss_port: s["wss_port"].as_u64().unwrap() as i32,
150            ws_port: s["ws_port"].as_u64().unwrap() as i32,
151        });
152    }
153    res
154}
155
156fn find_server(vd: Vec<DanmuServer>) -> (String, String, String) {
157    let (host, wss_port) = (vd.get(0).unwrap().host.clone(), vd.get(0).unwrap().wss_port);
158    (
159        host.clone(),
160        format!("{}:{}", host.clone(), wss_port),
161        format!("wss://{}:{}/sub", host, wss_port),
162    )
163}
164
165pub fn init_server(cookies: &str, room_id: &str) -> (Value, AuthMessage) {
166    let mut auth_map = HashMap::new();
167    let mut headers = reqwest::header::HeaderMap::new();
168    headers.insert(
169        reqwest::header::COOKIE,
170        reqwest::header::HeaderValue::from_str(cookies).unwrap(),
171    );
172    headers.insert(
173        reqwest::header::USER_AGENT,
174        reqwest::header::HeaderValue::from_static(crate::auth::USER_AGENT),
175    );
176    log::debug!("headers: {:?}", headers);
177
178    // Extract SESSDATA from cookies for authentication
179    let sessdata = cookies
180        .split(';')
181        .find_map(|kv| {
182            let mut parts = kv.trim().splitn(2, '=');
183            let key = parts.next()?.trim();
184            let value = parts.next()?.trim();
185            if key == "SESSDATA" {
186                Some(value.to_string())
187            } else {
188                None
189            }
190        })
191        .unwrap_or_else(|| "".to_string());
192
193    if !sessdata.is_empty() {
194        let (_, body1) = init_uid(headers.clone());
195        let body1_v: Value = serde_json::from_str(body1.as_str()).unwrap();
196
197        // Check if the authentication was successful
198        if let Some(mid) = body1_v["data"]["mid"].as_i64() {
199            auth_map.insert("uid".to_string(), mid.to_string());
200            log::info!("Successfully authenticated with uid: {}", mid);
201        } else {
202            log::warn!("Authentication failed - SESSDATA may be invalid or expired");
203            log::debug!("Auth response: {}", body1);
204            auth_map.insert("uid".to_string(), "0".to_string());
205        }
206    } else {
207        auth_map.insert("uid".to_string(), "0".to_string());
208    }
209    // here the live room id is easily obtained, so we not get it by url.
210    auth_map.insert("room_id".to_string(), room_id.to_string());
211
212    let room_id_num = room_id.parse::<u64>().expect("room_id must be a valid u64");
213    let (_, body4) = init_host_server(headers.clone(), room_id_num);
214    let body4_res: Value = serde_json::from_str(body4.as_str()).unwrap();
215    let server_info = &body4_res["data"];
216    let token = &body4_res["data"]["token"].as_str().unwrap();
217    auth_map.insert("token".to_string(), token.to_string());
218
219    let auth_msg = AuthMessage::from(&auth_map);
220    (server_info.clone(), auth_msg)
221}
222
223pub fn connect(v: Value) -> (WebSocket<TlsStream<TcpStream>>, Response<Option<Vec<u8>>>) {
224    let danmu_server = gen_damu_list(&v);
225    let (host, url, ws_url) = find_server(danmu_server);
226    let connector: native_tls::TlsConnector = native_tls::TlsConnector::new().unwrap();
227    let stream: TcpStream = TcpStream::connect(url).unwrap();
228    let stream: native_tls::TlsStream<TcpStream> =
229        connector.connect(host.as_str(), stream).unwrap();
230    let (socket, resp) =
231        client(Url::parse(ws_url.as_str()).unwrap(), stream).expect("Can't connect");
232    (socket, resp)
233}
234
/// Client-to-server packet kinds produced by [`make_packet`].
///
/// The variant names are kept as-is for backward compatibility.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Operation {
    /// Authentication packet (encoded as operation 7 by `make_packet`).
    AUTH,
    /// Heartbeat packet (encoded as operation 2 by `make_packet`).
    HEARTBEAT,
}
239
240pub fn make_packet(body: &str, ops: Operation) -> Vec<u8> {
241    let json: Value = serde_json::from_str(body).unwrap();
242    let temp = json.to_string();
243    let body_content: &[u8] = temp.as_bytes();
244    let pack_len: [u8; 4] = ((16 + body.len()) as u32).to_be_bytes();
245    let raw_header_size: [u8; 2] = (16 as u16).to_be_bytes();
246    let ver: [u8; 2] = (1 as u16).to_be_bytes();
247    let operation: [u8; 4] = match ops {
248        Operation::AUTH => (7 as u32).to_be_bytes(),
249        Operation::HEARTBEAT => (2 as u32).to_be_bytes(),
250    };
251    let seq_id: [u8; 4] = (1 as u32).to_be_bytes();
252    let mut res = pack_len.to_vec();
253    res.append(&mut raw_header_size.to_vec());
254    res.append(&mut ver.to_vec());
255    res.append(&mut operation.to_vec());
256    res.append(&mut seq_id.to_vec());
257    res.append(&mut body_content.to_vec());
258    res
259}
260
261pub fn get_msg_header(v_s: &[u8]) -> MsgHead {
262    let mut pack_len: [u8; 4] = [0; 4];
263    let mut raw_header_size: [u8; 2] = [0; 2];
264    let mut ver: [u8; 2] = [0; 2];
265    let mut operation: [u8; 4] = [0; 4];
266    let mut seq_id: [u8; 4] = [0; 4];
267    for (i, v) in v_s.iter().enumerate() {
268        if i < 4 {
269            pack_len[i] = *v;
270            continue;
271        }
272        if i < 6 {
273            raw_header_size[i - 4] = *v;
274            continue;
275        }
276        if i < 8 {
277            ver[i - 6] = *v;
278            continue;
279        }
280        if i < 12 {
281            operation[i - 8] = *v;
282            continue;
283        }
284        if i < 16 {
285            seq_id[i - 12] = *v;
286            continue;
287        }
288    }
289    MsgHead {
290        pack_len: u32::from_be_bytes(pack_len),
291        raw_header_size: u16::from_be_bytes(raw_header_size),
292        ver: u16::from_be_bytes(ver),
293        operation: u32::from_be_bytes(operation),
294        seq_id: u32::from_be_bytes(seq_id),
295    }
296}
297
298pub fn decompress(body: &[u8]) -> std::io::Result<Vec<u8>> {
299    use brotlic::DecompressorReader;
300    use std::io::Read;
301    let mut decompressed_reader: DecompressorReader<&[u8]> = DecompressorReader::new(body);
302    let mut decoded_input = Vec::new();
303    let _ = decompressed_reader.read_to_end(&mut decoded_input)?;
304    Ok(decoded_input)
305}
306
307/// here we detail [info format is online](https://github.com/SocialSisterYi/bilibili-API-collect/blob/master/docs/live/message_stream.md)
308/// .
309pub fn handle(json: Value) -> Option<BiliMessage> {
310    let category = json["cmd"].as_str().unwrap_or("");
311    match category {
312        "DANMU_MSG" => Some(BiliMessage::Danmu {
313            user: json["info"][2][1]
314                .as_str()
315                .unwrap_or("<unknown>")
316                .to_string(),
317            text: json["info"][1].as_str().unwrap_or("").to_string(),
318        }),
319        "SEND_GIFT" => Some(BiliMessage::Gift {
320            user: json["info"][2][1]
321                .as_str()
322                .unwrap_or("<unknown>")
323                .to_string(),
324            gift: json["info"][1].as_str().unwrap_or("").to_string(),
325        }),
326        // Add more cases for other types as needed
327        _ => Some(BiliMessage::Unsupported),
328    }
329}
330
331/// Enhanced init_server that can automatically detect cookies from browser
332pub fn init_server_auto(
333    provided_cookies: Option<&str>,
334    room_id: &str,
335) -> Result<(Value, AuthMessage), String> {
336    // Try to get cookies from provided value or browser cookies
337    let cookies = get_cookies_or_browser(provided_cookies)
338        .ok_or_else(|| "No cookies found in provided value or browser cookies. Please log into bilibili.com in your browser or provide cookies manually.".to_string())?;
339
340    log::info!(
341        "Using cookies for authentication: {}...",
342        &cookies[..10.min(cookies.len())]
343    );
344
345    let result = init_server(&cookies, room_id);
346    Ok(result)
347}
348
#[cfg(test)]
mod tests {
    use super::*;
    use futures_channel::mpsc::channel;

    /// End-to-end smoke test against the live Bilibili servers.
    ///
    /// Ignored by default because it needs network access. To exercise a real
    /// login, run with `Cookie='SESSDATA=...' cargo test -- --ignored`.
    #[test]
    #[ignore = "requires network access to Bilibili live servers"]
    fn test_bili_live_client_connect() {
        // Always enable debug log output for this test.
        let _ = env_logger::builder()
            .is_test(true)
            .filter_level(log::LevelFilter::Debug)
            .try_init();
        // Read cookies from the environment for a real test; otherwise use a dummy value.
        let cookies =
            std::env::var("Cookie").unwrap_or_else(|_| "SESSDATA=dummy_sessdata".to_string());
        let room_id = "24779526";
        let (tx, _rx) = channel(10);
        let _client = BiliLiveClient::new(&cookies, room_id, tx);
    }
}
368}