use std::sync::{Arc, Mutex};
use tokio::sync::mpsc;
use super::types::{
ChatResponseSharedState, ResponseEvent, StreamChunk, StreamError, ToolCallEvent,
};
use crate::types::Step;
/// Error type returned by the fallible `ChatResponseWriter` operations.
///
/// It carries nothing but a human-readable message; `Display` prints that
/// message verbatim.
#[derive(Debug)]
pub struct WriterError {
    /// Description of what went wrong.
    pub message: String,
}

impl WriterError {
    /// Builds an error from anything convertible into a `String`
    /// (for example a `&str` or an owned `String`).
    pub fn new(message: impl Into<String>) -> Self {
        let message = message.into();
        Self { message }
    }
}

impl std::fmt::Display for WriterError {
    /// Writes the stored message as-is, without applying any width or
    /// precision flags from the surrounding format spec.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(&self.message)
    }
}

// No `source()` override: the error wraps no underlying cause.
impl std::error::Error for WriterError {}
/// Lets a failed `mpsc` send be converted into a [`WriterError`], so callers
/// can use `?` or `map_err(WriterError::from)` on send results.
impl<T> From<mpsc::error::SendError<T>> for WriterError {
    /// Produces the message `"channel send failed: "` followed by the send
    /// error's `Display` output; the `SendError` itself is consumed.
    fn from(err: mpsc::error::SendError<T>) -> Self {
        let message = format!("channel send failed: {err}");
        Self { message }
    }
}
/// Bundles the channel senders used to emit the parts of a chat response,
/// together with mutex-guarded state shared with other holders of the `Arc`.
///
/// Each `*_tx` field is paired with the `send_*` method of the same name.
pub struct ChatResponseWriter {
    /// Sender for response text (`send_text`).
    pub(crate) text_tx: mpsc::Sender<String>,
    /// Sender for thought strings (`send_thought`).
    pub(crate) thought_tx: mpsc::Sender<String>,
    /// Sender for tool-call events (`send_tool_call`).
    pub(crate) tool_call_tx: mpsc::Sender<ToolCallEvent>,
    /// Sender for stream errors (`send_error`).
    pub(crate) error_tx: mpsc::Sender<StreamError>,
    /// Sender for response events (`send_event`).
    pub(crate) event_tx: mpsc::Sender<ResponseEvent>,
    /// Sender for steps (`send_step`).
    pub(crate) step_tx: mpsc::Sender<Step>,
    /// Sender for stream chunks (`send_chunk`).
    pub(crate) chunk_tx: mpsc::Sender<StreamChunk>,
    /// Shared state whose `usage` and `structured_output` fields are written
    /// by `set_usage` and `set_structured_output`.
    pub(crate) shared_state: Arc<Mutex<ChatResponseSharedState>>,
}
impl ChatResponseWriter {
    /// Sends `text` on the text channel.
    ///
    /// # Errors
    ///
    /// Returns a [`WriterError`] if the channel send fails.
    pub async fn send_text(&self, text: String) -> Result<(), WriterError> {
        self.text_tx.send(text).await?;
        Ok(())
    }

    /// Sends `thought` on the thought channel.
    ///
    /// # Errors
    ///
    /// Returns a [`WriterError`] if the channel send fails.
    pub async fn send_thought(&self, thought: String) -> Result<(), WriterError> {
        self.thought_tx.send(thought).await?;
        Ok(())
    }

    /// Sends a tool-call event on the tool-call channel.
    ///
    /// # Errors
    ///
    /// Returns a [`WriterError`] if the channel send fails.
    pub async fn send_tool_call(&self, event: ToolCallEvent) -> Result<(), WriterError> {
        self.tool_call_tx.send(event).await?;
        Ok(())
    }

    /// Sends a stream error on the error channel.
    ///
    /// # Errors
    ///
    /// Returns a [`WriterError`] if the channel send itself fails.
    pub async fn send_error(&self, error: StreamError) -> Result<(), WriterError> {
        self.error_tx.send(error).await?;
        Ok(())
    }

    /// Sends a response event on the event channel.
    ///
    /// # Errors
    ///
    /// Returns a [`WriterError`] if the channel send fails.
    pub async fn send_event(&self, event: ResponseEvent) -> Result<(), WriterError> {
        self.event_tx.send(event).await?;
        Ok(())
    }

    /// Sends a step on the step channel.
    ///
    /// # Errors
    ///
    /// Returns a [`WriterError`] if the channel send fails.
    pub async fn send_step(&self, step: crate::types::Step) -> Result<(), WriterError> {
        self.step_tx.send(step).await?;
        Ok(())
    }

    /// Sends a stream chunk on the chunk channel.
    ///
    /// # Errors
    ///
    /// Returns a [`WriterError`] if the channel send fails.
    pub async fn send_chunk(&self, chunk: StreamChunk) -> Result<(), WriterError> {
        self.chunk_tx.send(chunk).await?;
        Ok(())
    }

    /// Stores `usage` in the shared state, replacing any previous value.
    ///
    /// If the shared-state mutex is poisoned, the failure is logged with
    /// `tracing::error!` and `usage` is discarded; this method never panics
    /// on poisoning and reports nothing to the caller.
    pub fn set_usage(&self, usage: crate::types::UsageMetadata) {
        let mut state = match self.shared_state.lock() {
            Ok(guard) => guard,
            Err(poison) => {
                tracing::error!(
                    error = %poison,
                    "ChatResponseWriter shared_state mutex poisoned in set_usage"
                );
                return;
            }
        };
        state.usage = Some(usage);
    }

    /// Stores `value` as the structured output in the shared state,
    /// replacing any previous value.
    ///
    /// If the shared-state mutex is poisoned, the failure is logged with
    /// `tracing::error!` and `value` is discarded; this method never panics
    /// on poisoning and reports nothing to the caller.
    pub fn set_structured_output(&self, value: serde_json::Value) {
        let mut state = match self.shared_state.lock() {
            Ok(guard) => guard,
            Err(poison) => {
                tracing::error!(
                    error = %poison,
                    "ChatResponseWriter shared_state mutex poisoned in set_structured_output"
                );
                return;
            }
        };
        state.structured_output = Some(value);
    }
}