actor_core_client/drivers/
sse.rs

1use anyhow::{Context, Result};
2use base64::prelude::*;
3use eventsource_client::{BoxStream, Client, ClientBuilder, ReconnectOptionsBuilder, SSE};
4use futures_util::StreamExt;
5use serde_json::Value;
6use std::sync::Arc;
7use tokio::sync::mpsc;
8use tokio::task::JoinHandle;
9use tracing::debug;
10
11use crate::encoding::EncodingKind;
12use crate::protocol::{ToClient, ToClientBody, ToServer};
13
14use super::{
15    build_conn_url, DriverHandle, DriverStopReason, MessageToClient, MessageToServer, TransportKind,
16};
17
/// Identifiers captured from the server's `Init` packet during the handshake.
///
/// Both values are later embedded in the URL used to POST outbound messages.
#[derive(Clone, Debug, Eq, PartialEq)]
struct ConnectionDetails {
    /// Connection id (the `ci` field of the `Init` payload).
    id: String,
    /// Connection token (the `ct` field of the `Init` payload).
    token: String,
}
23
24pub(crate) async fn connect(
25    endpoint: String,
26    encoding_kind: EncodingKind,
27    parameters: &Option<Value>,
28) -> Result<(
29    DriverHandle,
30    mpsc::Receiver<MessageToClient>,
31    JoinHandle<DriverStopReason>,
32)> {
33    let url = build_conn_url(&endpoint, &TransportKind::Sse, encoding_kind, parameters)?;
34
35    let client = ClientBuilder::for_url(&url)?
36        .reconnect(ReconnectOptionsBuilder::new(false).build())
37        .build();
38
39    let (in_tx, in_rx) = mpsc::channel::<MessageToClient>(32);
40    let (out_tx, out_rx) = mpsc::channel::<MessageToServer>(32);
41
42    let task = tokio::spawn(start(client, endpoint, encoding_kind, in_tx, out_rx));
43
44    let handle = DriverHandle::new(out_tx, task.abort_handle());
45    Ok((handle, in_rx, task))
46}
47
48async fn start(
49    client: impl Client,
50    endpoint: String,
51    encoding_kind: EncodingKind,
52    in_tx: mpsc::Sender<MessageToClient>,
53    mut out_rx: mpsc::Receiver<MessageToServer>,
54) -> DriverStopReason {
55    let serialize = get_serializer(encoding_kind);
56    let deserialize = get_deserializer(encoding_kind);
57
58    let mut stream = client.stream();
59
60    let conn = match do_handshake(&mut stream, &deserialize, &in_tx).await {
61        Ok(conn) => conn,
62        Err(reason) => {
63            debug!("Failed to connect: {:?}", reason);
64            return reason;
65        }
66    };
67
68    loop {
69        tokio::select! {
70            msg = out_rx.recv() => {
71                let Some(msg) = msg else {
72                    return DriverStopReason::UserAborted;
73                };
74
75                let msg = match serialize(&msg) {
76                    Ok(msg) => msg,
77                    Err(e) => {
78                        debug!("Failed to serialize {:?} {:?}", msg, e);
79                        continue;
80                    }
81                };
82
83                // Add connection ID and token to the request URL
84                let request_url = format!(
85                    "{}/connections/{}/message?encoding={}&connectionToken={}",
86                    endpoint, conn.id, encoding_kind.as_str(), urlencoding::encode(&conn.token)
87                );
88
89                // Handle response
90                let resp = reqwest::Client::new()
91                    .post(request_url)
92                    .body(msg)
93                    .send()
94                    .await;
95
96                match resp {
97                    Ok(resp) => {
98                        if !resp.status().is_success() {
99                            debug!("Failed to send message: {:?}", resp);
100                        }
101
102                        if let Ok(t) = resp.text().await {
103                            debug!("Response: {:?}", t);
104                        }
105                    },
106                    Err(e) => {
107                        debug!("Failed to send message: {:?}", e);
108                    }
109                }
110            },
111            // Handle sse incoming
112            msg = stream.next() => {
113                let Some(msg) = msg else {
114                    debug!("Receiver dropped");
115                    return DriverStopReason::ServerDisconnect;
116                };
117
118                match msg {
119                    Ok(msg) => match msg {
120                        SSE::Comment(comment) => debug!("Sse comment: {}", comment),
121                        SSE::Connected(_) => debug!("warning: received sse connection past-handshake"),
122                        SSE::Event(event) => {
123                            // println!("POST INIT event coming in: {:?}", event.data);
124                            let msg = match deserialize(&event.data) {
125                                Ok(msg) => msg,
126                                Err(e) => {
127                                    debug!("Failed to deserialize {:?} {:?}", event, e);
128                                    continue;
129                                }
130                            };
131
132                            if let Err(e) = in_tx.send(Arc::new(msg)).await {
133                                debug!("Receiver in_rx dropped {:?}", e);
134                                return DriverStopReason::UserAborted;
135                            }
136                        },
137                    }
138                    Err(e) => {
139                        debug!("Sse error: {}", e);
140                        return DriverStopReason::ServerError;
141                    }
142                }
143            }
144        }
145    }
146}
147
148async fn do_handshake(
149    stream: &mut BoxStream<eventsource_client::Result<SSE>>,
150    deserialize: &impl Fn(&str) -> Result<ToClient>,
151    in_tx: &mpsc::Sender<MessageToClient>,
152) -> Result<ConnectionDetails, DriverStopReason> {
153    loop {
154        tokio::select! {
155            // Handle sse incoming
156            msg = stream.next() => {
157                let Some(msg) = msg else {
158                    debug!("Receiver dropped");
159                    return Err(DriverStopReason::ServerDisconnect);
160                };
161
162                match msg {
163                    Ok(msg) => match msg {
164                        SSE::Comment(comment) => debug!("Sse comment {:?}", comment),
165                        SSE::Connected(_) => debug!("Connected Sse"),
166                        SSE::Event(event) => {
167                            let msg = match deserialize(&event.data) {
168                                Ok(msg) => msg,
169                                Err(e) => {
170                                    debug!("Failed to deserialize {:?} {:?}", event, e);
171                                    continue;
172                                }
173                            };
174
175                            let msg = Arc::new(msg);
176
177                            if let Err(e) = in_tx.send(msg.clone()).await {
178                                debug!("Receiver in_rx dropped {:?}", e);
179                                return Err(DriverStopReason::UserAborted);
180                            }
181
182                            // Wait until we get an Init packet
183                            let ToClientBody::Init { i } = &msg.b else {
184                                continue;
185                            };
186
187                            // Mark handshake complete
188                            let conn_id = &i.ci;
189                            let conn_token = &i.ct;
190
191                            return Ok(ConnectionDetails {
192                                id: conn_id.clone(),
193                                token: conn_token.clone()
194                            })
195                        },
196                    }
197                    Err(e) => {
198                        eprintln!("Sse error: {}", e);
199                        return Err(DriverStopReason::ServerError);
200                    }
201                }
202            }
203        }
204    }
205}
206
207fn get_serializer(encoding_kind: EncodingKind) -> impl Fn(&ToServer) -> Result<Vec<u8>> {
208    encoding_kind.get_default_serializer()
209}
210
211fn get_deserializer(encoding_kind: EncodingKind) -> impl Fn(&str) -> Result<ToClient> {
212    match encoding_kind {
213        EncodingKind::Json => json_deserialize,
214        EncodingKind::Cbor => cbor_deserialize,
215    }
216}
217
218fn json_deserialize(value: &str) -> Result<ToClient> {
219    let msg = serde_json::from_str::<ToClient>(value)?;
220
221    Ok(msg)
222}
223
224fn cbor_deserialize(msg: &str) -> Result<ToClient> {
225    let msg = BASE64_STANDARD
226        .decode(msg.as_bytes())
227        .context("base64 failure:")?;
228    let msg = serde_cbor::from_slice::<ToClient>(&msg).context("serde failure:")?;
229
230    Ok(msg)
231}