// edge_gpt/session.rs

1use crate::{conversation_meta, ConversationMeta, CookieInFile};
2use async_stream::try_stream;
3use base64::{engine::general_purpose, Engine};
4use futures_util::{SinkExt, Stream, StreamExt};
5use rand::{distributions::Slice, Rng};
6use reqwest::header::{HeaderMap, HeaderValue};
7use serde::{Deserialize, Serialize};
8use serde_json::{json, Value};
9use std::pin::Pin;
10use thiserror::Error;
11use tokio::net::TcpStream;
12use tokio_tungstenite::{
13    connect_async,
14    tungstenite::{
15        http::{self},
16        Message,
17    },
18    MaybeTlsStream, WebSocketStream,
19};
20use uuid::Uuid;
21
/// SignalR record separator (ASCII `RS`, `0x1e`) appended after every JSON message sent and
/// used to split incoming text frames into individual records.
const DELIMITER: u8 = b'\x1e';
23
24fn random_hex_string(length: usize) -> String {
25    let hex_charactors: Vec<char> = "0123456789abcdef".chars().collect();
26    rand::thread_rng()
27        .sample_iter(Slice::new(&hex_charactors).unwrap())
28        .take(length)
29        .collect()
30}
31
32fn headers(uuid: &str, forwarded_ip: &str) -> HeaderMap {
33    let mut headers = HeaderMap::new();
34    headers.insert("accept", HeaderValue::from_static("application/json"));
35    headers.insert(
36        "accept-language",
37        HeaderValue::from_static("en-US,en;q=0.9"),
38    );
39    headers.insert("content-type", HeaderValue::from_static("application/json"));
40    headers.insert(
41        "sec-ch-ua",
42        HeaderValue::from_static(
43            "\"Not_A Brand\";v=\"99\", \"Microsoft Edge\";v=\"110\", \"Chromium\";v=\"110\"",
44        ),
45    );
46    headers.insert("sec-ch-ua-arch", HeaderValue::from_static("\"x86\""));
47    headers.insert("sec-ch-ua-bitness", HeaderValue::from_static("\"64\""));
48    headers.insert(
49        "sec-ch-ua-full-version",
50        HeaderValue::from_static("\"109.0.1518.78\""),
51    );
52    headers.insert(
53        "sec-ch-ua-full-version-list",
54        HeaderValue::from_static(
55            "\"Chromium\";v=\"110.0.5481.192\", \"Not A(Brand\";v=\"24.0.0.0\", \"Microsoft Edge\";v=\"110.0.1587.69\"",
56        ),
57    );
58    headers.insert("sec-ch-ua-mobile", HeaderValue::from_static("?0"));
59    headers.insert("sec-ch-ua-model", HeaderValue::from_static(""));
60    headers.insert(
61        "sec-ch-ua-platform",
62        HeaderValue::from_static("\"Windows\""),
63    );
64    headers.insert(
65        "sec-ch-ua-platform-version",
66        HeaderValue::from_static("\"15.0.0\""),
67    );
68    headers.insert("sec-fetch-dest", HeaderValue::from_static("empty"));
69    headers.insert("sec-fetch-mode", HeaderValue::from_static("cors"));
70    headers.insert("sec-fetch-site", HeaderValue::from_static("same-origin"));
71    headers.insert(
72        "x-ms-useragent",
73        HeaderValue::from_static(
74            "azsdk-js-api-client-factory/1.0.0-beta.1 core-rest-pipeline/1.10.0 OS/Win32",
75        ),
76    );
77    headers.insert(
78        "Referer",
79        HeaderValue::from_static("https://www.bing.com/search?q=Bing+AI&showconv=1&FORM=hpcodx"),
80    );
81    headers.insert(
82        "Referrer-Policy",
83        HeaderValue::from_static("origin-when-cross-origin"),
84    );
85
86    headers.insert(
87        "x-ms-client-request-id",
88        HeaderValue::from_str(uuid).unwrap(),
89    );
90    headers.insert(
91        "x-forwarded-for",
92        HeaderValue::from_str(forwarded_ip).unwrap(),
93    );
94    let websocket_key = random_hex_string(16);
95    let websocket_key_base64 = general_purpose::STANDARD.encode(websocket_key);
96    headers.insert(
97        "Sec-websocket-key",
98        HeaderValue::from_str(&websocket_key_base64).unwrap(),
99    );
100    headers.insert("Sec-WebSocket-Version", HeaderValue::from_static("13"));
101    headers.insert("Connection", HeaderValue::from_static("Upgrade"));
102    headers.insert("Upgrade", HeaderValue::from_static("websocket"));
103    headers.insert("Host", HeaderValue::from_static("sydney.bing.com"));
104    headers
105}
106
107fn random_forwarded_ip() -> String {
108    let mut rng = rand::thread_rng();
109    format!(
110        "13.{}.{}.{}",
111        rng.gen_range(104u8..=107u8),
112        rng.gen_range(0u8..=255),
113        rng.gen_range(0u8..=255)
114    )
115}
116
117/// Conversation Style of bing.
118#[derive(Serialize, Deserialize, Debug, Clone, Copy)]
119pub enum ConversationStyle {
120    #[serde(rename = "h3imaginative")]
121    Creative,
122    #[serde(rename = "galileo")]
123    Balanced,
124    #[serde(rename = "h3precise")]
125    Precise,
126}
127
128impl From<ConversationStyle> for &'static str {
129    fn from(val: ConversationStyle) -> Self {
130        match val {
131            ConversationStyle::Creative => "h3imaginative",
132            ConversationStyle::Balanced => "galileo",
133            ConversationStyle::Precise => "h3precise",
134        }
135    }
136}
137
138/// A session represent a chat with bing.
139/// It implements `Serialize` and `Deserialize`
140/// Thus can be dumped to/load from external storage to pause and continue a chat.
141#[derive(Debug, Serialize, Deserialize)]
142pub struct ChatSession {
143    conversation_meta: ConversationMeta,
144    invocation_id: usize,
145    uuid: String,
146    ip: String,
147    style: ConversationStyle,
148}
149
150/// Response provided by bing.
151#[derive(Debug, Clone, Serialize, Deserialize)]
152pub struct NewBingResponseMessage {
153    /// text content of the response.
154    pub text: String,
155    /// suggested responses of the response.
156    pub suggested_responses: Vec<String>,
157    /// source attributions of the response.
158    pub source_attributions: Vec<String>,
159}
160
161#[derive(Serialize, Deserialize, Debug, Clone)]
162#[serde(rename_all = "camelCase")]
163struct NewBingRequestMessage {
164    author: &'static str,
165    input_method: &'static str,
166    text: String,
167    message_type: &'static str,
168}
169
170impl NewBingRequestMessage {
171    fn new(text: String) -> Self {
172        Self {
173            author: "user",
174            input_method: "Keyboard",
175            text,
176            message_type: "Chat",
177        }
178    }
179}
180
181#[derive(Serialize, Deserialize, Debug, Clone)]
182struct Participant {
183    id: String,
184}
185
186#[derive(Serialize, Deserialize, Debug, Clone)]
187#[serde(rename_all = "camelCase")]
188struct Argument {
189    source: &'static str,
190    options_sets: [&'static str; 10],
191    slice_ids: [&'static str; 3],
192    trace_id: String,
193    is_start_of_session: bool,
194    message: NewBingRequestMessage,
195    conversation_signature: String,
196    participant: Participant,
197    conversation_id: String,
198}
199
200impl Argument {
201    pub fn new(
202        conversation_meta: ConversationMeta,
203        style: ConversationStyle,
204        is_start_of_session: bool,
205        text: &str,
206    ) -> Self {
207        Self {
208            source: "cib",
209            options_sets: [
210                "nlu_direct_response_filter",
211                "deepleo",
212                "disable_emoji_spoken_text",
213                "responsible_ai_policy_235",
214                "enablemm",
215                style.into(),
216                "dtappid",
217                "cricinfo",
218                "cricinfov2",
219                "dv3sugg",
220            ],
221            slice_ids: ["222dtappid", "225cricinfo", "224locals0"],
222            trace_id: random_hex_string(32),
223            is_start_of_session,
224            message: NewBingRequestMessage::new(text.to_string()),
225            conversation_signature: conversation_meta.conversation_signature.to_string(),
226            participant: Participant {
227                id: conversation_meta.client_id.to_string(),
228            },
229            conversation_id: conversation_meta.conversation_id,
230        }
231    }
232}
233
234#[derive(Serialize, Deserialize, Debug, Clone)]
235#[serde(rename_all = "camelCase")]
236struct NewBingRequest {
237    arguments: [Argument; 1],
238    invocation_id: String,
239    target: &'static str,
240    #[serde(rename = "type")]
241    message_type: u8,
242}
243
244impl NewBingRequest {
245    fn new(
246        conversation_meta: ConversationMeta,
247        style: ConversationStyle,
248        invocation_id: usize,
249        text: &str,
250    ) -> Self {
251        Self {
252            arguments: [Argument::new(
253                conversation_meta,
254                style,
255                invocation_id == 0,
256                text,
257            )],
258            invocation_id: format!("{invocation_id}"),
259            target: "chat",
260            message_type: 4,
261        }
262    }
263}
264
265#[derive(Debug, Clone)]
266enum SignalRNewBingResponse {
267    Invocation(NewBingResponseMessage),
268    StreamItem(NewBingResponseMessage),
269    EndOfResponse,
270    Ping,
271    Unknown,
272}
273
274impl<'de> serde::Deserialize<'de> for SignalRNewBingResponse {
275    fn deserialize<D: serde::Deserializer<'de>>(d: D) -> std::result::Result<Self, D::Error> {
276        let value = Value::deserialize(d)?;
277
278        Ok(match value.get("type").and_then(Value::as_u64).unwrap() {
279            1 => SignalRNewBingResponse::Invocation(deserialize_invocation(value).unwrap()),
280            2 => SignalRNewBingResponse::StreamItem(deserialize_newbing_response(value).unwrap()),
281            3 => SignalRNewBingResponse::EndOfResponse,
282            6 => SignalRNewBingResponse::Ping,
283            _ => SignalRNewBingResponse::Unknown,
284        })
285    }
286}
287
288fn deserialize_invocation(value: Value) -> Result<NewBingResponseMessage> {
289    let content = value["arguments"][0]["messages"][0]["text"]
290        .as_str()
291        .unwrap_or("")
292        .to_string();
293    let res = NewBingResponseMessage {
294        text: content,
295        suggested_responses: vec![],
296        source_attributions: vec![],
297    };
298    Ok(res)
299}
300
301fn deserialize_newbing_response(value: Value) -> Result<NewBingResponseMessage> {
302    let content = &value
303        .get("item")
304        .ok_or(ChatError::GetFieldError {
305            object_name: "newbing_response",
306            field_name: "item",
307        })?
308        .get("messages")
309        .ok_or(ChatError::GetFieldError {
310            object_name: "newbing_response.item",
311            field_name: "messages",
312        })?
313        .as_array()
314        .ok_or(ChatError::FieldTypeError {
315            object_name: "newbing_response.item",
316            field_name: "messages",
317            expected_type: "str",
318        })?
319        .iter()
320        .find(|msg| {
321            msg.get("messageType").is_none()
322                && msg
323                    .get("author")
324                    .and_then(|author| author.as_str())
325                    .map(|it| it == "bot")
326                    .unwrap_or(false)
327        })
328        .ok_or(ChatError::FieldTypeError {
329            object_name: "newbing_response.item",
330            field_name: "messages[{author == bot, messageType != null}]",
331            expected_type: "str",
332        })?;
333    let text = content
334        .get("text")
335        .ok_or(ChatError::GetFieldError {
336            object_name: "newbing_response.item.messages[{author == bot, messageType != null}]",
337            field_name: "text",
338        })?
339        .as_str()
340        .ok_or(ChatError::FieldTypeError {
341            object_name: "newbing_response.item.messages[{author == bot, messageType != null}]",
342            field_name: "text",
343            expected_type: "str",
344        })?
345        .to_string();
346    let suggested_responses = content
347        .get("suggestedResponses")
348        .ok_or(ChatError::GetFieldError {
349            object_name: "newbing_response.item.messages[{author == bot, messageType != null}]",
350            field_name: "suggestedResponses",
351        })?
352        .as_array()
353        .ok_or(ChatError::FieldTypeError {
354            object_name: "newbing_response.item.messages[{author == bot, messageType != null}]",
355            field_name: "suggestedResponses",
356            expected_type: "array",
357        })?
358        .iter()
359        .map(|suggested_response|{
360            Ok(suggested_response
361                .get("text")
362                .ok_or(ChatError::GetFieldError {
363                    object_name: "newbing_response.item.messages[{author == bot, messageType != null}].suggestedResponses",
364                    field_name: "text",
365                })?
366                .as_str()
367                .ok_or(ChatError::FieldTypeError {
368                    object_name: "newbing_response.item.messages[{author == bot, messageType != null}].suggestedResponses",
369                    field_name: "text",
370                    expected_type: "str",
371                })?
372                .to_string())
373        })
374        .collect::<Result<_>>()?;
375    let source_attributions = content
376        .get("sourceAttributions")
377        .ok_or(ChatError::GetFieldError {
378            object_name: "newbing_response.item.messages[{author == bot, messageType != null}]",
379            field_name: "sourceAttributions",
380        })?
381        .as_array()
382        .ok_or(ChatError::FieldTypeError {
383            object_name: "newbing_response.item.messages[{author == bot, messageType != null}]",
384            field_name: "sourceAttributions",
385            expected_type: "array",
386        })?
387        .iter()
388        .map(|it| {
389            Ok(it
390                .get("seeMoreUrl")
391                .ok_or(ChatError::GetFieldError {
392                    object_name:
393                        "newbing_response.item.messages[{author == bot, messageType != null}].sourceAttributions",
394                    field_name: "seeMoreUrl",
395                })?
396                .as_str()
397                .ok_or(ChatError::FieldTypeError {
398                    object_name:
399                        "newbing_response.item.messages[{author == bot, messageType != null}].sourceAttributions",
400                    field_name: "seeMoreUrl",
401                    expected_type: "str",
402                })?
403                .to_string())
404        })
405        .collect::<Result<_>>()?;
406    Ok(NewBingResponseMessage {
407        text,
408        suggested_responses,
409        source_attributions,
410    })
411}
412
413impl ChatSession {
414    pub fn new(
415        conversation_meta: ConversationMeta,
416        style: ConversationStyle,
417        invocation_id: usize,
418        uuid: String,
419        ip: String,
420    ) -> Self {
421        Self {
422            style,
423            conversation_meta,
424            invocation_id,
425            uuid,
426            ip,
427        }
428    }
429
430    /// Create a new [`ChatSession`] from cookies.
431    pub async fn create(
432        style: ConversationStyle,
433        cookies: &[CookieInFile],
434    ) -> conversation_meta::Result<Self> {
435        let uuid = Uuid::new_v4().hyphenated();
436        let uuid = uuid.encode_lower(&mut Uuid::encode_buffer()).to_string();
437        Ok(Self {
438            conversation_meta: ConversationMeta::create(cookies).await?,
439            invocation_id: 0,
440            uuid,
441            ip: random_forwarded_ip(),
442            style,
443        })
444    }
445
446    /// Create a new [`ChatStream`] for chatting with the bot in a [`Stream`].
447    pub async fn chat_stream(&mut self, text: &str) -> Result<ChatStream> {
448        let mut request = http::Request::builder()
449            .uri("wss://sydney.bing.com/sydney/ChatHub")
450            .body(())
451            .unwrap();
452        *(request.headers_mut()) = headers(&self.uuid, &self.ip);
453        let (mut ws_stream, _) = connect_async(request)
454            .await
455            .map_err(|_| ChatError::Network)?;
456        let mut handshake_message =
457            serde_json::to_vec(&json!({"protocol": "json", "version": 1})).unwrap();
458        handshake_message.push(DELIMITER);
459        let message = Message::Binary(handshake_message);
460        ws_stream
461            .send(message)
462            .await
463            .map_err(|_| ChatError::Network)?;
464        let _response = ws_stream
465            .next()
466            .await
467            .unwrap()
468            .map_err(|_| ChatError::Network)?;
469
470        let mut alive_message = serde_json::to_vec(&json!({"type": 6})).unwrap();
471        alive_message.push(DELIMITER);
472        let message = Message::Binary(alive_message);
473        ws_stream.send(message).await.unwrap();
474
475        let msg = NewBingRequest::new(
476            self.conversation_meta.clone(),
477            self.style,
478            self.invocation_id,
479            text,
480        );
481        let mut question_message = serde_json::to_vec(&msg).unwrap();
482        question_message.push(DELIMITER);
483        let message = Message::Binary(question_message);
484        ws_stream.send(message).await.unwrap();
485        self.invocation_id += 1;
486        let stream = chat_stream(ws_stream);
487        Ok(Box::pin(stream))
488    }
489
490    /// Send a message to the session, and return the response.
491    pub async fn send_message(&mut self, text: &str) -> Result<NewBingResponseMessage> {
492        let mut request = http::Request::builder()
493            .uri("wss://sydney.bing.com/sydney/ChatHub")
494            .body(())
495            .unwrap();
496        *(request.headers_mut()) = headers(&self.uuid, &self.ip);
497        let (ws_stream, _) = connect_async(request)
498            .await
499            .map_err(|_| ChatError::Network)?;
500        let (mut write, mut read) = ws_stream.split();
501        let mut handshake_message =
502            serde_json::to_vec(&json!({"protocol": "json", "version": 1})).unwrap();
503        handshake_message.push(DELIMITER);
504        let message = Message::Binary(handshake_message);
505        write.send(message).await.map_err(|_| ChatError::Network)?;
506
507        let _response = read.next().await.unwrap().map_err(|_| ChatError::Network)?;
508
509        let mut alive_message = serde_json::to_vec(&json!({"type": 6})).unwrap();
510        alive_message.push(DELIMITER);
511        let message = Message::Binary(alive_message);
512        write.send(message).await.unwrap();
513
514        let msg = NewBingRequest::new(
515            self.conversation_meta.clone(),
516            self.style,
517            self.invocation_id,
518            text,
519        );
520        let mut question_message = serde_json::to_vec(&msg).unwrap();
521        question_message.push(DELIMITER);
522        let message = Message::Binary(question_message);
523        write.send(message).await.unwrap();
524        self.invocation_id += 1;
525
526        while let Some(Ok(response)) = read.next().await {
527            if let Message::Text(content) = response {
528                let signal_r_packages = content
529                    .split('\u{1e}')
530                    .map(|it| it.trim())
531                    .filter(|it| !it.is_empty())
532                    .collect::<Vec<_>>();
533                for signal_r_package in signal_r_packages {
534                    let response: SignalRNewBingResponse = serde_json::from_str(signal_r_package)?;
535                    match response {
536                        SignalRNewBingResponse::StreamItem(message) => return Ok(message),
537                        SignalRNewBingResponse::EndOfResponse => {
538                            break;
539                        }
540                        SignalRNewBingResponse::Ping => {
541                            let mut alive_message =
542                                serde_json::to_vec(&json!({"type": 6})).unwrap();
543                            alive_message.push(DELIMITER);
544                            let message = Message::Binary(alive_message);
545                            write.send(message).await.map_err(|_| ChatError::Network)?;
546                        }
547                        _ => {}
548                    }
549                }
550            }
551        }
552        Err(ChatError::NoFullResponseFound)
553    }
554}
555
556#[derive(Error, Debug)]
557pub enum ChatError {
558    #[error("Failed to get field {field_name} from {object_name}")]
559    GetFieldError {
560        object_name: &'static str,
561        field_name: &'static str,
562    },
563    #[error("{object_name}.{field_name} should be of type {expected_type}")]
564    FieldTypeError {
565        object_name: &'static str,
566        field_name: &'static str,
567        expected_type: &'static str,
568    },
569    #[error("Failed to send chat request")]
570    Network,
571    #[error("Failed to parse chat response")]
572    ParseRespond(#[from] serde_json::Error),
573    #[error("No full response received")]
574    NoFullResponseFound,
575    #[error("No response received")]
576    NoResponse,
577}
578
579pub type Result<T> = std::result::Result<T, ChatError>;
580pub type ChatStream = Pin<Box<dyn Stream<Item = Result<NewBingResponseMessage>>>>;
581
582fn chat_stream(
583    wss: WebSocketStream<MaybeTlsStream<TcpStream>>,
584) -> impl Stream<Item = Result<NewBingResponseMessage>> {
585    try_stream! {
586        let (mut write, mut read) = wss.split();
587            'outer:while let Some(Ok(msg)) = read.next().await {
588                if let Message::Text(text) = msg {
589                    let packs = text
590                        .split('\u{1e}')
591                        .map(|it| it.trim())
592                        .filter(|it| !it.is_empty())
593                        .collect::<Vec<_>>();
594                    for pack in packs {
595                        let response = serde_json::from_str::<SignalRNewBingResponse>(pack);
596                        let response = match response {
597                            Ok(response) => match response {
598
599                                SignalRNewBingResponse::Invocation(res)
600                                | SignalRNewBingResponse::StreamItem(res) => Ok(res),
601                                SignalRNewBingResponse::Ping => {
602                                    let mut alive_message =
603                                        serde_json::to_vec(&json!({"type": 6})).unwrap();
604                                    alive_message.push(DELIMITER);
605                                    let message = Message::Binary(alive_message);
606                                    let e =
607                                        write.send(message).await.map_err(|_| ChatError::Network);
608                                    if let Err(e) = e {
609                                        Err(e)
610                                    } else {
611                                        continue;
612                                    }
613                                }
614                                SignalRNewBingResponse::EndOfResponse => break 'outer,
615                                _ => continue,
616                            },
617                            Err(_) => continue,
618                        };
619                        let response=response?;
620                        yield response;
621                    }
622                }
623            }
624    }
625}