1use clau_core::{Error, Result, Message, StreamFormat};
2use futures::{Stream, StreamExt};
3use std::pin::Pin;
4use std::task::{Context, Poll};
5use tokio::sync::mpsc;
6use tracing::error;
7
8pub struct MessageStream {
9 receiver: mpsc::Receiver<Result<Message>>,
10}
11
12impl MessageStream {
13 pub fn new(receiver: mpsc::Receiver<Result<Message>>, _format: StreamFormat) -> Self {
14 Self { receiver }
15 }
16
17 pub async fn collect_full_response(mut self) -> Result<String> {
18 let mut response = String::new();
19
20 while let Some(result) = self.next().await {
21 match result? {
22 Message::Assistant { content, .. } => {
23 response.push_str(&content);
24 }
25 Message::Result { .. } => {
26 break;
28 }
29 _ => {}
30 }
31 }
32
33 Ok(response)
34 }
35}
36
37impl Stream for MessageStream {
38 type Item = Result<Message>;
39
40 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
41 self.receiver.poll_recv(cx)
42 }
43}
44
45pub struct MessageParser {
46 format: StreamFormat,
47}
48
49impl MessageParser {
50 pub fn new(format: StreamFormat) -> Self {
51 Self { format }
52 }
53
54 pub fn parse_line(&self, line: &str) -> Result<Option<Message>> {
55 match self.format {
56 StreamFormat::Text => {
57 Ok(None)
59 }
60 StreamFormat::Json | StreamFormat::StreamJson => {
61 if line.trim().is_empty() {
62 return Ok(None);
63 }
64
65 match serde_json::from_str::<Message>(line) {
66 Ok(message) => Ok(Some(message)),
67 Err(e) => {
68 error!("Failed to parse message: {}, line: {}", e, line);
69 Err(Error::SerializationError(e))
70 }
71 }
72 }
73 }
74 }
75
76 pub fn parse_text_response(&self, text: &str) -> Message {
77 Message::Assistant {
79 content: text.to_string(),
80 meta: clau_core::MessageMeta {
81 session_id: "text-response".to_string(),
82 timestamp: Some(std::time::SystemTime::now()),
83 cost_usd: None,
84 duration_ms: None,
85 tokens_used: None,
86 },
87 }
88 }
89}