1use crate::context::{Context, Receiver};
75use crate::error::{Mecha10Error, Result};
76use crate::messages::Message;
77use serde::{de::DeserializeOwned, Deserialize, Serialize};
78use std::future::Future;
79use std::marker::PhantomData;
80use std::sync::Arc;
81use tokio::sync::Mutex;
82use tracing::{debug, error, info, warn};
83use uuid::Uuid;
84
85#[derive(Debug, Clone, Serialize, Deserialize)]
91pub struct StreamMessage<T> {
92 pub stream_id: String,
94
95 pub sequence: u64,
97
98 pub payload: StreamPayload<T>,
100
101 pub timestamp: u64,
103}
104
105impl<T: Message> Message for StreamMessage<T> {}
106
107#[derive(Debug, Clone, Serialize, Deserialize)]
109pub enum StreamPayload<T> {
110 Data(T),
112
113 Start,
115
116 End,
118
119 Error(StreamError),
121
122 Cancel,
124
125 Heartbeat,
127}
128
129#[derive(Debug, Clone, Serialize, Deserialize)]
131pub struct StreamError {
132 pub message: String,
133 pub code: Option<String>,
134}
135
136pub struct StreamReceiver<T> {
144 stream_id: String,
145 receiver: Receiver<StreamMessage<T>>,
146 sequence: u64,
147 ended: bool,
148}
149
150impl<T: Message + DeserializeOwned + Send + 'static> StreamReceiver<T> {
151 pub(crate) fn new(stream_id: String, receiver: Receiver<StreamMessage<T>>) -> Self {
153 Self {
154 stream_id,
155 receiver,
156 sequence: 0,
157 ended: false,
158 }
159 }
160
161 pub async fn recv(&mut self) -> Option<T> {
167 if self.ended {
168 return None;
169 }
170
171 loop {
172 let msg = self.receiver.recv().await?;
173
174 if msg.stream_id != self.stream_id {
176 continue;
177 }
178
179 if msg.sequence < self.sequence {
181 warn!(
182 "Out of order message in stream {} (expected {}, got {})",
183 self.stream_id, self.sequence, msg.sequence
184 );
185 }
186 self.sequence = msg.sequence + 1;
187
188 match msg.payload {
189 StreamPayload::Data(data) => {
190 debug!("Stream {} received message #{}", self.stream_id, msg.sequence);
191 return Some(data);
192 }
193 StreamPayload::End => {
194 info!("Stream {} ended normally", self.stream_id);
195 self.ended = true;
196 return None;
197 }
198 StreamPayload::Error(err) => {
199 error!("Stream {} error: {}", self.stream_id, err.message);
200 self.ended = true;
201 return None;
202 }
203 StreamPayload::Cancel => {
204 info!("Stream {} cancelled", self.stream_id);
205 self.ended = true;
206 return None;
207 }
208 StreamPayload::Start => {
209 debug!("Stream {} started", self.stream_id);
210 }
212 StreamPayload::Heartbeat => {
213 debug!("Stream {} heartbeat", self.stream_id);
214 }
216 }
217 }
218 }
219
220 pub fn is_ended(&self) -> bool {
222 self.ended
223 }
224
225 pub fn stream_id(&self) -> &str {
227 &self.stream_id
228 }
229}
230
231#[derive(Clone)]
239pub struct StreamSender<T: Message + Serialize + Clone> {
240 stream_id: String,
241 topic: String,
242 ctx: Arc<Context>,
243 sequence: Arc<Mutex<u64>>,
244 ended: Arc<Mutex<bool>>,
245 _phantom: PhantomData<T>,
246}
247
248impl<T: Message + Serialize + Clone> StreamSender<T> {
249 pub(crate) fn new(stream_id: String, topic: String, ctx: Arc<Context>) -> Self {
251 Self {
252 stream_id,
253 topic,
254 ctx,
255 sequence: Arc::new(Mutex::new(0)),
256 ended: Arc::new(Mutex::new(false)),
257 _phantom: PhantomData,
258 }
259 }
260
261 pub async fn send(&self, data: &T) -> Result<()> {
263 if *self.ended.lock().await {
264 return Err(Mecha10Error::MessagingError {
265 message: "Stream already ended".to_string(),
266 suggestion: "Cannot send after stream end".to_string(),
267 });
268 }
269
270 let mut seq = self.sequence.lock().await;
271 let message = StreamMessage {
272 stream_id: self.stream_id.clone(),
273 sequence: *seq,
274 payload: StreamPayload::Data(data.clone()),
275 timestamp: now_micros(),
276 };
277
278 *seq += 1;
279 drop(seq);
280
281 self.ctx.publish_raw(&self.topic, &message).await?;
282 debug!("Stream {} sent message #{}", self.stream_id, message.sequence);
283
284 Ok(())
285 }
286
287 pub async fn end(&self) -> Result<()> {
289 let mut ended = self.ended.lock().await;
290 if *ended {
291 return Ok(()); }
293
294 let seq = self.sequence.lock().await;
295 let message = StreamMessage {
296 stream_id: self.stream_id.clone(),
297 sequence: *seq,
298 payload: StreamPayload::<T>::End,
299 timestamp: now_micros(),
300 };
301
302 self.ctx.publish_raw(&self.topic, &message).await?;
303 *ended = true;
304
305 info!("Stream {} ended", self.stream_id);
306 Ok(())
307 }
308
309 pub async fn error(&self, error_message: &str) -> Result<()> {
311 let mut ended = self.ended.lock().await;
312 if *ended {
313 return Ok(());
314 }
315
316 let seq = self.sequence.lock().await;
317 let message = StreamMessage {
318 stream_id: self.stream_id.clone(),
319 sequence: *seq,
320 payload: StreamPayload::<T>::Error(StreamError {
321 message: error_message.to_string(),
322 code: None,
323 }),
324 timestamp: now_micros(),
325 };
326
327 self.ctx.publish_raw(&self.topic, &message).await?;
328 *ended = true;
329
330 error!("Stream {} error: {}", self.stream_id, error_message);
331 Ok(())
332 }
333
334 pub async fn heartbeat(&self) -> Result<()> {
336 if *self.ended.lock().await {
337 return Ok(());
338 }
339
340 let seq = self.sequence.lock().await;
341 let message = StreamMessage {
342 stream_id: self.stream_id.clone(),
343 sequence: *seq,
344 payload: StreamPayload::<T>::Heartbeat,
345 timestamp: now_micros(),
346 };
347
348 self.ctx.publish_raw(&self.topic, &message).await?;
349 debug!("Stream {} heartbeat", self.stream_id);
350
351 Ok(())
352 }
353
354 pub async fn is_ended(&self) -> bool {
356 *self.ended.lock().await
357 }
358
359 pub fn stream_id(&self) -> &str {
361 &self.stream_id
362 }
363}
364
365impl<T: Message + Serialize + Clone> Drop for StreamSender<T> {
367 fn drop(&mut self) {
368 debug!("StreamSender dropped for stream {}", self.stream_id);
371 }
372}
373
374pub trait StreamingRpcExt {
380 fn stream_request<Req, Resp>(
394 &self,
395 topic: &str,
396 request: &Req,
397 ) -> impl Future<Output = Result<StreamReceiver<Resp>>> + Send
398 where
399 Req: Message + Serialize + Clone,
400 Resp: Message + DeserializeOwned + Send + 'static;
401
402 fn stream_respond<Req, Resp, F, Fut>(&self, topic: &str, handler: F) -> impl Future<Output = Result<()>> + Send
416 where
417 Req: Message + DeserializeOwned + Send + 'static,
418 Resp: Message + Serialize + Clone + Send + 'static,
419 F: Fn(Req, StreamSender<Resp>) -> Fut + Send + Sync + 'static,
420 Fut: Future<Output = Result<()>> + Send + 'static;
421
422 fn bidirectional_stream<T>(
438 &self,
439 topic: &str,
440 ) -> impl Future<Output = Result<(StreamSender<T>, StreamReceiver<T>)>> + Send
441 where
442 T: Message + Serialize + DeserializeOwned + Clone + Send + 'static;
443}
444
445impl StreamingRpcExt for Context {
446 async fn stream_request<Req, Resp>(&self, topic: &str, request: &Req) -> Result<StreamReceiver<Resp>>
447 where
448 Req: Message + Serialize + Clone,
449 Resp: Message + DeserializeOwned + Send + 'static,
450 {
451 let stream_id = Uuid::new_v4().to_string();
452 let request_topic = format!("{}/stream/request", topic);
453 let response_topic = format!("{}/stream/response", topic);
454
455 let receiver = self.subscribe_raw::<StreamMessage<Resp>>(&response_topic).await?;
457
458 let start_msg = StreamMessage {
460 stream_id: stream_id.clone(),
461 sequence: 0,
462 payload: StreamPayload::Data(request.clone()),
463 timestamp: now_micros(),
464 };
465
466 self.publish_raw(&request_topic, &start_msg).await?;
467
468 info!("Stream request started: {} (stream_id: {})", topic, stream_id);
469
470 Ok(StreamReceiver::new(stream_id, receiver))
471 }
472
473 async fn stream_respond<Req, Resp, F, Fut>(&self, topic: &str, handler: F) -> Result<()>
474 where
475 Req: Message + DeserializeOwned + Send + 'static,
476 Resp: Message + Serialize + Clone + Send + 'static,
477 F: Fn(Req, StreamSender<Resp>) -> Fut + Send + Sync + 'static,
478 Fut: Future<Output = Result<()>> + Send + 'static,
479 {
480 let request_topic = format!("{}/stream/request", topic);
481 let response_topic = format!("{}/stream/response", topic);
482
483 let mut requests = self.subscribe_raw::<StreamMessage<Req>>(&request_topic).await?;
484 let ctx = Arc::new(self.clone());
485 let handler = Arc::new(handler);
486
487 info!("Stream responder registered for: {}", topic);
488
489 tokio::spawn(async move {
490 while let Some(msg) = requests.recv().await {
491 match msg.payload {
492 StreamPayload::Data(request) => {
493 let stream_id = msg.stream_id.clone();
494 let sender = StreamSender::new(stream_id.clone(), response_topic.clone(), Arc::clone(&ctx));
495
496 let handler = Arc::clone(&handler);
497
498 tokio::spawn(async move {
500 debug!("Handling stream request: {}", stream_id);
501
502 if let Err(e) = handler(request, sender.clone()).await {
503 error!("Stream handler error ({}): {}", stream_id, e);
504 let _ = sender.error(&e.to_string()).await;
505 } else {
506 let _ = sender.end().await;
507 }
508 });
509 }
510 StreamPayload::Cancel => {
511 debug!("Stream cancelled: {}", msg.stream_id);
512 }
513 _ => {
514 }
516 }
517 }
518 });
519
520 Ok(())
521 }
522
523 async fn bidirectional_stream<T>(&self, topic: &str) -> Result<(StreamSender<T>, StreamReceiver<T>)>
524 where
525 T: Message + Serialize + DeserializeOwned + Clone + Send + 'static,
526 {
527 let stream_id = Uuid::new_v4().to_string();
528 let client_topic = format!("{}/bidir/client", topic);
529 let server_topic = format!("{}/bidir/server", topic);
530
531 let receiver = self.subscribe_raw::<StreamMessage<T>>(&server_topic).await?;
533
534 let sender = StreamSender::new(stream_id.clone(), client_topic, Arc::new(self.clone()));
536
537 info!("Bidirectional stream created: {} (stream_id: {})", topic, stream_id);
538
539 Ok((sender, StreamReceiver::new(stream_id, receiver)))
540 }
541}
542
/// Current wall-clock time in microseconds since the Unix epoch.
///
/// Never panics:
/// - returns 0 if the system clock reads earlier than the epoch (e.g. an
///   unset RTC);
/// - saturates at `u64::MAX` instead of silently truncating the `u128`
///   microsecond count.
fn now_micros() -> u64 {
    use std::time::{SystemTime, UNIX_EPOCH};
    SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map_or(0, |elapsed| u64::try_from(elapsed.as_micros()).unwrap_or(u64::MAX))
}