actor_core_client/drivers/
sse.rs1use anyhow::{Context, Result};
2use base64::prelude::*;
3use eventsource_client::{BoxStream, Client, ClientBuilder, ReconnectOptionsBuilder, SSE};
4use futures_util::StreamExt;
5use serde_json::Value;
6use std::sync::Arc;
7use tokio::sync::mpsc;
8use tokio::task::JoinHandle;
9use tracing::debug;
10
11use crate::encoding::EncodingKind;
12use crate::protocol::{ToClient, ToClientBody, ToServer};
13
14use super::{
15 build_conn_url, DriverHandle, DriverStopReason, MessageToClient, MessageToServer, TransportKind,
16};
17
/// Identity of an established connection, copied out of the server's `Init`
/// message during the handshake.
#[derive(Clone, Debug, Eq, PartialEq)]
struct ConnectionDetails {
    /// Connection id; interpolated into the `/connections/{id}/message` path.
    id: String,
    /// Connection token; sent URL-encoded as the `connectionToken` query parameter.
    token: String,
}
23
24pub(crate) async fn connect(
25 endpoint: String,
26 encoding_kind: EncodingKind,
27 parameters: &Option<Value>,
28) -> Result<(
29 DriverHandle,
30 mpsc::Receiver<MessageToClient>,
31 JoinHandle<DriverStopReason>,
32)> {
33 let url = build_conn_url(&endpoint, &TransportKind::Sse, encoding_kind, parameters)?;
34
35 let client = ClientBuilder::for_url(&url)?
36 .reconnect(ReconnectOptionsBuilder::new(false).build())
37 .build();
38
39 let (in_tx, in_rx) = mpsc::channel::<MessageToClient>(32);
40 let (out_tx, out_rx) = mpsc::channel::<MessageToServer>(32);
41
42 let task = tokio::spawn(start(client, endpoint, encoding_kind, in_tx, out_rx));
43
44 let handle = DriverHandle::new(out_tx, task.abort_handle());
45 Ok((handle, in_rx, task))
46}
47
48async fn start(
49 client: impl Client,
50 endpoint: String,
51 encoding_kind: EncodingKind,
52 in_tx: mpsc::Sender<MessageToClient>,
53 mut out_rx: mpsc::Receiver<MessageToServer>,
54) -> DriverStopReason {
55 let serialize = get_serializer(encoding_kind);
56 let deserialize = get_deserializer(encoding_kind);
57
58 let mut stream = client.stream();
59
60 let conn = match do_handshake(&mut stream, &deserialize, &in_tx).await {
61 Ok(conn) => conn,
62 Err(reason) => {
63 debug!("Failed to connect: {:?}", reason);
64 return reason;
65 }
66 };
67
68 loop {
69 tokio::select! {
70 msg = out_rx.recv() => {
71 let Some(msg) = msg else {
72 return DriverStopReason::UserAborted;
73 };
74
75 let msg = match serialize(&msg) {
76 Ok(msg) => msg,
77 Err(e) => {
78 debug!("Failed to serialize {:?} {:?}", msg, e);
79 continue;
80 }
81 };
82
83 let request_url = format!(
85 "{}/connections/{}/message?encoding={}&connectionToken={}",
86 endpoint, conn.id, encoding_kind.as_str(), urlencoding::encode(&conn.token)
87 );
88
89 let resp = reqwest::Client::new()
91 .post(request_url)
92 .body(msg)
93 .send()
94 .await;
95
96 match resp {
97 Ok(resp) => {
98 if !resp.status().is_success() {
99 debug!("Failed to send message: {:?}", resp);
100 }
101
102 if let Ok(t) = resp.text().await {
103 debug!("Response: {:?}", t);
104 }
105 },
106 Err(e) => {
107 debug!("Failed to send message: {:?}", e);
108 }
109 }
110 },
111 msg = stream.next() => {
113 let Some(msg) = msg else {
114 debug!("Receiver dropped");
115 return DriverStopReason::ServerDisconnect;
116 };
117
118 match msg {
119 Ok(msg) => match msg {
120 SSE::Comment(comment) => debug!("Sse comment: {}", comment),
121 SSE::Connected(_) => debug!("warning: received sse connection past-handshake"),
122 SSE::Event(event) => {
123 let msg = match deserialize(&event.data) {
125 Ok(msg) => msg,
126 Err(e) => {
127 debug!("Failed to deserialize {:?} {:?}", event, e);
128 continue;
129 }
130 };
131
132 if let Err(e) = in_tx.send(Arc::new(msg)).await {
133 debug!("Receiver in_rx dropped {:?}", e);
134 return DriverStopReason::UserAborted;
135 }
136 },
137 }
138 Err(e) => {
139 debug!("Sse error: {}", e);
140 return DriverStopReason::ServerError;
141 }
142 }
143 }
144 }
145 }
146}
147
148async fn do_handshake(
149 stream: &mut BoxStream<eventsource_client::Result<SSE>>,
150 deserialize: &impl Fn(&str) -> Result<ToClient>,
151 in_tx: &mpsc::Sender<MessageToClient>,
152) -> Result<ConnectionDetails, DriverStopReason> {
153 loop {
154 tokio::select! {
155 msg = stream.next() => {
157 let Some(msg) = msg else {
158 debug!("Receiver dropped");
159 return Err(DriverStopReason::ServerDisconnect);
160 };
161
162 match msg {
163 Ok(msg) => match msg {
164 SSE::Comment(comment) => debug!("Sse comment {:?}", comment),
165 SSE::Connected(_) => debug!("Connected Sse"),
166 SSE::Event(event) => {
167 let msg = match deserialize(&event.data) {
168 Ok(msg) => msg,
169 Err(e) => {
170 debug!("Failed to deserialize {:?} {:?}", event, e);
171 continue;
172 }
173 };
174
175 let msg = Arc::new(msg);
176
177 if let Err(e) = in_tx.send(msg.clone()).await {
178 debug!("Receiver in_rx dropped {:?}", e);
179 return Err(DriverStopReason::UserAborted);
180 }
181
182 let ToClientBody::Init { i } = &msg.b else {
184 continue;
185 };
186
187 let conn_id = &i.ci;
189 let conn_token = &i.ct;
190
191 return Ok(ConnectionDetails {
192 id: conn_id.clone(),
193 token: conn_token.clone()
194 })
195 },
196 }
197 Err(e) => {
198 eprintln!("Sse error: {}", e);
199 return Err(DriverStopReason::ServerError);
200 }
201 }
202 }
203 }
204 }
205}
206
207fn get_serializer(encoding_kind: EncodingKind) -> impl Fn(&ToServer) -> Result<Vec<u8>> {
208 encoding_kind.get_default_serializer()
209}
210
211fn get_deserializer(encoding_kind: EncodingKind) -> impl Fn(&str) -> Result<ToClient> {
212 match encoding_kind {
213 EncodingKind::Json => json_deserialize,
214 EncodingKind::Cbor => cbor_deserialize,
215 }
216}
217
218fn json_deserialize(value: &str) -> Result<ToClient> {
219 let msg = serde_json::from_str::<ToClient>(value)?;
220
221 Ok(msg)
222}
223
224fn cbor_deserialize(msg: &str) -> Result<ToClient> {
225 let msg = BASE64_STANDARD
226 .decode(msg.as_bytes())
227 .context("base64 failure:")?;
228 let msg = serde_cbor::from_slice::<ToClient>(&msg).context("serde failure:")?;
229
230 Ok(msg)
231}