use async_trait::async_trait;
use futures::stream::BoxStream;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use tokio_util::sync::CancellationToken;
use crate::tool::ToolCall;
use crate::types::{AgentMessage, AssistantContent, StopReason};
/// Input for a single [`StreamFn::stream`] call.
#[derive(Debug, Clone)]
pub struct StreamRequest {
    /// System prompt text for this call.
    pub system_prompt: String,
    /// Conversation messages sent with this call.
    pub messages: Vec<AgentMessage>,
    /// Tools made available for this call, each described by a [`ToolSchema`].
    pub tools: Vec<ToolSchema>,
    /// Optional sampling temperature.
    // NOTE(review): `None` presumably means "use the provider default" — confirm in implementations.
    pub temperature: Option<f32>,
    /// Optional limit on the number of output tokens.
    // NOTE(review): `None` presumably means "no explicit limit / provider default" — confirm.
    pub max_output_tokens: Option<u32>,
    /// Requested reasoning effort; see [`ReasoningEffort`].
    pub reasoning: ReasoningEffort,
    /// Provider-specific options as an arbitrary JSON value.
    ///
    /// The expected shape is not defined in this file; each provider
    /// implementation decides what it reads from here.
    // NOTE(review): no reason is recorded for this lint suppression; document why it is
    // needed, or remove it if clippy no longer flags this field.
    #[allow(clippy::struct_field_names)]
    pub provider_extras: Value,
    /// Whether this call should force the model to call a tool.
    // NOTE(review): how strictly this is honoured is provider-specific — confirm in implementations.
    pub force_tool_call: bool,
}
/// Reasoning effort requested from the model for a [`StreamRequest`].
///
/// Serde (de)serializes each variant as its lowercase name: `"none"`,
/// `"minimal"`, `"low"`, `"medium"`, `"high"`, `"xhigh"`. These are the same
/// strings [`ReasoningEffort::as_wire`] returns.
///
/// Serde deserialization is exact-match. [`ReasoningEffort::from_wire`] is more
/// lenient: it also trims whitespace and ignores ASCII case.
///
/// The default is [`ReasoningEffort::Minimal`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum ReasoningEffort {
    // This variant is unrelated to `Option::None`. Spell it `ReasoningEffort::None`
    // (or `Self::None`) wherever an `Option<ReasoningEffort>` is also in play,
    // e.g. the return value of `from_wire`.
    // NOTE(review): presumably asks the provider to skip reasoning — confirm per provider.
    None,
    #[default]
    Minimal,
    Low,
    Medium,
    High,
    XHigh,
}
impl ReasoningEffort {
pub fn as_wire(self) -> &'static str {
match self {
ReasoningEffort::None => "none",
ReasoningEffort::Minimal => "minimal",
ReasoningEffort::Low => "low",
ReasoningEffort::Medium => "medium",
ReasoningEffort::High => "high",
ReasoningEffort::XHigh => "xhigh",
}
}
pub fn from_wire(s: &str) -> Option<Self> {
match s.trim().to_ascii_lowercase().as_str() {
"none" => Some(ReasoningEffort::None),
"minimal" => Some(ReasoningEffort::Minimal),
"low" => Some(ReasoningEffort::Low),
"medium" => Some(ReasoningEffort::Medium),
"high" => Some(ReasoningEffort::High),
"xhigh" => Some(ReasoningEffort::XHigh),
_ => None,
}
}
}
/// Description of one tool offered to the model via [`StreamRequest::tools`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolSchema {
    /// Tool name.
    pub name: String,
    /// Human-readable description of the tool.
    pub description: String,
    /// Parameter description as raw JSON.
    // NOTE(review): presumably a JSON Schema object — confirm against provider implementations.
    pub parameters: Value,
}
/// One item of the stream returned by [`StreamFn::stream`].
// NOTE(review): the expected ordering (presumably `Start`, then any number of
// `Chunk`s, then exactly one `Done` or `Error`) is not enforced by this type —
// confirm the contract with implementations and consumers.
#[derive(Debug, Clone)]
pub enum StreamEvent {
    /// Opening event carrying an initial partial [`AgentMessage`].
    Start { partial: AgentMessage },
    /// An incremental piece of output; see [`AssistantStreamChunk`].
    Chunk(AssistantStreamChunk),
    /// Successful completion carrying the final [`AgentMessage`].
    Done { message: AgentMessage },
    /// Failure report.
    ///
    /// - `partial`: the message built so far (assumed — confirm).
    /// - `kind`: a [`StreamErrorKind`] classification.
    /// - `message`: error text.
    Error {
        partial: AgentMessage,
        kind: StreamErrorKind,
        message: String,
    },
}
/// An incremental piece of assistant output, carried by [`StreamEvent::Chunk`].
///
/// Serialized as an internally tagged object. The `"kind"` field holds the
/// snake_case variant name (`"text"`, `"thinking"`, `"reasoning"`,
/// `"reasoning_details"`, `"tool_call_delta"`), alongside that variant's
/// fields, e.g. `{"kind": "text", "delta": "Hel"}`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "kind", rename_all = "snake_case")]
pub enum AssistantStreamChunk {
    /// Fragment of assistant text.
    Text { delta: String },
    // NOTE(review): `Thinking` and `Reasoning` are separate variants; which
    // providers emit which is not visible here — confirm.
    /// Fragment of "thinking" text.
    Thinking { delta: String },
    /// Fragment of reasoning text.
    Reasoning { delta: String },
    /// Raw JSON reasoning-detail entries, passed through as uninterpreted `Value`s.
    ReasoningDetails { delta: Vec<Value> },
    /// Partial update to the tool call identified by `index`.
    ///
    /// Each `*_delta` field is optional, so a chunk may carry only some parts
    /// of the call.
    // NOTE(review): presumably `Some` fragments are appended to earlier fragments
    // with the same `index` — confirm in the consumer that assembles tool calls.
    ToolCallDelta {
        index: usize,
        id_delta: Option<String>,
        name_delta: Option<String>,
        arguments_delta: Option<String>,
    },
}
/// Classification attached to [`StreamEvent::Error`].
///
/// Serialized as the snake_case variant name, e.g. `"transient"`,
/// `"provider_rate_limited"`, `"zero_output_transport"`, `"context_overflow"`.
// NOTE(review): the precise meaning of each kind (and which are retryable) is
// decided by `StreamFn` implementations and their callers, none of which are
// visible here.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum StreamErrorKind {
    Transient,
    ProviderRateLimited,
    ZeroOutputTransport,
    Fatal,
    Empty,
    // NOTE(review): presumably reported when the `CancellationToken` passed to
    // `StreamFn::stream` is cancelled — confirm.
    Aborted,
    ContextOverflow,
}
/// Assistant output gathered into one value.
///
/// Holds the content, the stop reason, optional error text, and the tool calls.
// NOTE(review): nothing in this file constructs it; presumably built by a consumer
// that drains a `StreamFn` stream — confirm.
#[derive(Debug, Clone)]
pub struct StreamResponse {
    /// Assistant content produced by the call.
    pub content: AssistantContent,
    /// Why generation stopped.
    pub stop_reason: StopReason,
    /// Error text, if any.
    // NOTE(review): presumably `Some` only when the stream ended in an error — confirm.
    pub error_message: Option<String>,
    /// Tool calls requested by the assistant.
    pub tool_calls: Vec<ToolCall>,
}
/// A streaming model backend.
///
/// The `Send + Sync` supertraits let one implementation be shared across
/// threads. The `'static` bound means the implementing type holds no
/// non-`'static` borrows. `#[async_trait]` is what allows the `async fn` in
/// the trait.
#[async_trait]
pub trait StreamFn: Send + Sync + 'static {
    /// Starts a streaming call for `request` and returns its events.
    ///
    /// The method has no error return; failures are reported in-band as
    /// [`StreamEvent::Error`] items.
    ///
    /// The returned stream is `'static`, so it cannot borrow from `&self` or
    /// `request`. Implementations must move or clone whatever they need into
    /// it. `signal` is a cancellation token for this call.
    // NOTE(review): presumably cancelling `signal` ends the stream with a
    // `StreamEvent::Error` of kind `StreamErrorKind::Aborted` — confirm in implementations.
    async fn stream(
        &self,
        request: StreamRequest,
        signal: CancellationToken,
    ) -> BoxStream<'static, StreamEvent>;
}