use crate::agent::MCPError;
use futures::{Stream, StreamExt};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::fmt::Debug;
use std::pin::Pin;
use tokio::sync::mpsc;
use tokio_stream::wrappers::ReceiverStream;
/// One unit of streamed output from a model/tool response.
///
/// Produced by [`process_json_stream`] and consumed through a [`TokenStream`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StreamingToken {
    // Text payload of this token (empty on the terminal finish marker).
    pub content: String,
    // True only on the final token signalling end-of-stream.
    pub is_finish: bool,
    // Optional provider-specific extra data; None everywhere in this module.
    pub metadata: Option<Value>,
}
/// Pinned, boxed, `Send` stream of token results — the common currency
/// returned by the streaming helpers in this module.
pub type TokenStream = Pin<Box<dyn Stream<Item = Result<StreamingToken, MCPError>> + Send>>;
/// Adapts an mpsc receiver of token results into a [`TokenStream`].
///
/// The returned stream yields each item sent on the channel and ends when
/// every sender has been dropped.
pub fn create_token_stream(
    receiver: mpsc::Receiver<Result<StreamingToken, MCPError>>,
) -> TokenStream {
    Box::pin(ReceiverStream::new(receiver))
}
pub async fn process_json_stream<S, T>(stream: S) -> Result<TokenStream, MCPError>
where
S: Stream<Item = Result<bytes::Bytes, reqwest::Error>> + Send + 'static,
T: for<'de> Deserialize<'de> + Send + 'static + Debug,
{
let (tx, rx) = mpsc::channel(100);
tokio::spawn(async move {
let mut stream = Box::pin(stream);
let mut buffer = String::new();
while let Some(chunk_result) = stream.next().await {
match chunk_result {
Ok(chunk) => {
let chunk_str = String::from_utf8_lossy(&chunk);
buffer.push_str(&chunk_str);
while let Some(pos) = buffer.find('\n') {
let line = buffer[..pos].trim().to_string();
buffer = buffer[pos + 1..].to_string();
if line.is_empty() || line == "data: [DONE]" {
continue;
}
let json_str = line.strip_prefix("data: ").unwrap_or(&line);
match serde_json::from_str::<T>(json_str) {
Ok(parsed) => {
let token = StreamingToken {
content: format!("{:?}", parsed),
is_finish: false,
metadata: None,
};
if tx.send(Ok(token)).await.is_err() {
break;
}
}
Err(e) => {
let _ = tx
.send(Err(MCPError::InternalAgentError(format!(
"Erro ao desserializar: {}",
e
))))
.await;
}
}
}
}
Err(e) => {
let _ = tx
.send(Err(MCPError::InternalAgentError(format!(
"Erro de rede: {}",
e
))))
.await;
break;
}
}
}
let _ = tx
.send(Ok(StreamingToken {
content: String::new(),
is_finish: true,
metadata: None,
}))
.await;
});
Ok(create_token_stream(rx))
}
#[cfg(test)]
mod tests {
    use super::*;
    use futures::stream;

    /// Sends one token (with the given text and finish flag) into the channel,
    /// panicking if the receiver is gone.
    async fn push_token(
        tx: &mpsc::Sender<Result<StreamingToken, MCPError>>,
        text: &str,
        finish: bool,
    ) {
        tx.send(Ok(StreamingToken {
            content: text.to_string(),
            is_finish: finish,
            metadata: None,
        }))
        .await
        .unwrap();
    }

    #[tokio::test]
    async fn test_create_token_stream() {
        let (tx, rx) = mpsc::channel(10);
        push_token(&tx, "Token 1", false).await;
        push_token(&tx, "Token 2", false).await;
        push_token(&tx, "", true).await;

        let mut token_stream = create_token_stream(rx);
        let mut contents = Vec::new();
        let mut finished = false;
        while let Some(item) = token_stream.next().await {
            let token = item.unwrap();
            if token.is_finish {
                finished = true;
                break;
            }
            contents.push(token.content);
        }
        assert_eq!(contents, vec!["Token 1", "Token 2"]);
        assert!(finished);
    }

    #[derive(Deserialize, Debug)]
    struct TestResponse {
        #[allow(dead_code)]
        text: String,
    }

    #[tokio::test]
    async fn test_process_json_stream() {
        // Two JSON payloads split across chunks, then the SSE terminator.
        let chunks: Vec<Result<bytes::Bytes, reqwest::Error>> = vec![
            Ok(bytes::Bytes::from(r#"{"text":"Parte 1"}"#)),
            Ok(bytes::Bytes::from("\n")),
            Ok(bytes::Bytes::from(r#"{"text":"Parte 2"}"#)),
            Ok(bytes::Bytes::from("\n")),
            Ok(bytes::Bytes::from("data: [DONE]\n")),
        ];
        let mut token_stream = process_json_stream::<_, TestResponse>(stream::iter(chunks))
            .await
            .unwrap();
        let mut contents = Vec::new();
        while let Some(item) = token_stream.next().await {
            let token = item.unwrap();
            if token.is_finish {
                break;
            }
            contents.push(token.content);
        }
        assert_eq!(contents.len(), 2);
        assert!(contents[0].contains("Parte 1"));
        assert!(contents[1].contains("Parte 2"));
    }
}