use crate::event::AISdkEvent;
use crate::stream::AISdkStreamBuilder;
use async_stream::stream;
use futures::{Stream, StreamExt};
use rig::agent::MultiTurnStreamItem;
use rig::streaming::{StreamedAssistantContent, StreamedUserContent, ToolCallDeltaContent};
use std::collections::{HashMap};
use std::fmt::Display;
/// Marker substring recognised in rig stream errors by [`adapt_rig_stream`].
///
/// When an error's `Display` text contains this marker, the adapter stops reading the rig
/// stream without emitting an AI SDK error event; it still emits the usual finish/done events.
/// NOTE(review): presumably raised on purpose to hand a tool call off to the frontend — confirm
/// against the code that produces this error.
pub const FRONTEND_TOOL_CANCEL_REASON: &str = "__FRONTEND_TOOL__";
pub fn adapt_rig_stream_sse<S, R, E>(
rig_stream: S,
) -> impl Stream<Item = Result<axum::response::sse::Event, std::convert::Infallible>>
where
S: Stream<Item = Result<MultiTurnStreamItem<R>, E>> + Unpin,
E: Display,
{
adapt_rig_stream(rig_stream).map(|event| {
Ok(axum::response::sse::Event::from(
event.unwrap_or_else(|e| AISdkEvent::error(e.to_string())),
))
})
}
/// Converts a rig multi-turn stream into a stream of AI SDK events.
///
/// The output always begins with the builder's start event and ends with its finish and done
/// events. Each successful rig item is translated by `convert_stream_item`.
///
/// The first `Err` from `rig_stream` stops the conversion. If its `Display` text contains
/// [`FRONTEND_TOOL_CANCEL_REASON`], no event is emitted for it; otherwise an AI SDK error event
/// carrying the message is emitted. The finish/done events follow in both cases. Every item is
/// yielded as `Ok`; this function never produces the `E` side of the item type.
pub fn adapt_rig_stream<S, R, E>(rig_stream: S) -> impl Stream<Item = Result<AISdkEvent, E>>
where
    S: Stream<Item = Result<MultiTurnStreamItem<R>, E>> + Unpin,
    E: Display,
{
    stream! {
        let mut builder = AISdkStreamBuilder::new();
        let mut tool_names: HashMap<String, String> = HashMap::new();
        let mut source = rig_stream;

        yield Ok(builder.start());

        while let Some(next) = source.next().await {
            match next {
                Ok(item) => {
                    for converted in convert_stream_item(&mut builder, &mut tool_names, item) {
                        yield Ok(converted);
                    }
                }
                Err(failure) => {
                    let message = failure.to_string();
                    // A frontend-tool cancellation is an expected stop, not something to report.
                    if !message.contains(FRONTEND_TOOL_CANCEL_REASON) {
                        yield Ok(AISdkEvent::error(message));
                    }
                    break;
                }
            }
        }

        yield Ok(builder.finish());
        yield Ok(builder.done());
    }
}
/// Converts one rig multi-turn stream item into zero or more AI SDK events.
///
/// * `events` – stateful builder that tracks open text/reasoning parts. Its `text_start`,
///   `text_delta`, `text_end`, `reasoning_delta` and `reasoning_end` helpers return `Option`s.
///   `reasoning_end` is called on every text chunk, which implies these return `None` when there
///   is nothing to emit (e.g. no open part). NOTE(review): confirm against `AISdkStreamBuilder`.
/// * `tool_names` – records the tool name per tool-call id as names stream in. Nothing in this
///   file reads it back.
/// * `item` – the rig stream item to translate.
///
/// Returns the events to forward, in emission order. Items without an AI SDK counterpart yield
/// an empty vector.
fn convert_stream_item<R>(
    events: &mut AISdkStreamBuilder,
    tool_names: &mut HashMap<String, String>,
    item: MultiTurnStreamItem<R>,
) -> Vec<AISdkEvent> {
    let mut result = Vec::new();
    match item {
        MultiTurnStreamItem::StreamUserItem(StreamedUserContent::ToolResult(tool_result)) => {
            // Prefer the provider-facing call id; fall back to rig's own id when it is absent.
            let tool_call_id = tool_result
                .call_id
                .as_ref()
                .unwrap_or(&tool_result.id)
                .clone();
            result.push(AISdkEvent::ToolOutputAvailable {
                tool_call_id,
                // A serialization failure degrades to JSON `null` instead of aborting the stream.
                output: serde_json::to_value(&tool_result.content).unwrap_or_default(),
                provider_executed: None,
                dynamic: None,
                preliminary: None,
            });
        }
        MultiTurnStreamItem::StreamAssistantItem(assistant) => match assistant {
            StreamedAssistantContent::Text(text) => {
                // Text following reasoning closes the reasoning part before the text part opens.
                result.extend(events.reasoning_end());
                result.extend(events.text_start());
                result.extend(events.text_delta(text.text));
            }
            StreamedAssistantContent::ToolCall(tool_call) => {
                // Same id preference as for tool results, so input and output events match up.
                let tool_call_id = tool_call.call_id.as_ref().unwrap_or(&tool_call.id).clone();
                result.push(AISdkEvent::ToolInputAvailable {
                    tool_call_id,
                    tool_name: tool_call.function.name,
                    input: tool_call.function.arguments,
                    provider_executed: None,
                    provider_metadata: None,
                    dynamic: None,
                });
            }
            StreamedAssistantContent::ToolCallDelta { id, content } => match content {
                ToolCallDeltaContent::Name(name) => {
                    tool_names.insert(id.clone(), name.clone());
                    result.push(AISdkEvent::ToolInputStart {
                        tool_call_id: id,
                        tool_name: name,
                        provider_executed: None,
                        provider_metadata: None,
                        dynamic: None,
                    });
                }
                ToolCallDeltaContent::Delta(delta) => {
                    result.push(AISdkEvent::ToolInputDelta {
                        tool_call_id: id,
                        delta,
                    });
                }
            },
            StreamedAssistantContent::Reasoning(reasoning) => {
                result.push(events.reasoning_start(reasoning.id));
                for chunk in &reasoning.reasoning {
                    result.extend(events.reasoning_delta(chunk, None));
                }
            }
            StreamedAssistantContent::ReasoningDelta { reasoning, id } => {
                result.extend(events.reasoning_delta(reasoning, id));
            }
            StreamedAssistantContent::Final(_) => {
                // End of the assistant turn. Reasoning is otherwise only closed when text
                // arrives, so a turn of reasoning followed by a tool call (or reasoning alone)
                // would leave it open. Close it here, then any open text part.
                result.extend(events.reasoning_end());
                result.extend(events.text_end());
            }
        },
        MultiTurnStreamItem::FinalResponse(final_response) => {
            result.push(AISdkEvent::custom_data("usage", final_response.usage()));
        }
        _ => {}
    }
    result
}