use std::io::{self, Write};
use anyhow::{Context, Result};
use async_openai::{config::OpenAIConfig, types::CreateChatCompletionRequest, Client};
use futures_util::StreamExt;
use tokio::signal::unix::{signal, SignalKind};
use tracing::{debug, info, warn};
use crate::cli::color::red;
use crate::cli::jarvis::{
jarvis_print_chunk, jarvis_print_end, jarvis_print_prefix, jarvis_spinner,
};
use super::tools::call::{accumulate_tool_call, ToolCallAccumulator};
/// Outcome of one streamed chat-completion round.
pub struct StreamResult {
    /// All assistant text accumulated from the streamed content deltas
    /// (may be partial if `interrupted` is true).
    pub full_text: String,
    /// Tool calls assembled from the streamed tool-call fragments.
    pub tool_calls: Vec<ToolCallAccumulator>,
    /// True if the user pressed Ctrl-C before the stream completed;
    /// `full_text` / `tool_calls` then hold whatever arrived so far.
    pub interrupted: bool,
}
/// Streams one chat-completion round from the API, echoing text deltas to the
/// terminal as they arrive and accumulating any tool-call fragments.
///
/// A SIGINT (Ctrl-C) at any point — while waiting for the connection or
/// mid-stream — cancels gracefully: whatever text and tool calls were
/// gathered so far are returned with `interrupted = true` instead of an error.
///
/// # Arguments
/// * `client` - OpenAI-compatible client used to open the chat stream.
/// * `request` - Fully built chat-completion request to send.
/// * `is_first_round` - Only read by the final debug log in this function;
///   presumably distinguishes the initial turn from tool-call follow-up
///   rounds — confirm against callers.
///
/// # Errors
/// Fails if the SIGINT handler cannot be registered, if the stream cannot be
/// created, or if the stream yields a transport/API error mid-way.
pub async fn process_stream(
    client: &Client<OpenAIConfig>,
    request: CreateChatCompletionRequest,
    is_first_round: bool,
) -> Result<StreamResult> {
    // Register the Ctrl-C listener up front so interrupts are caught even
    // while we are still waiting for the API connection.
    let mut sigint =
        signal(SignalKind::interrupt()).context("Failed to register SIGINT handler")?;
    // Spinner shown until the first visible output (text or tool call) lands.
    let spinner = jarvis_spinner();
    let chat = client.chat();
    // Race stream creation against Ctrl-C. Both arms are cancel-safe:
    // dropping the half-finished `create_stream` future on interrupt is fine.
    let mut stream = tokio::select! {
        result = chat.create_stream(request) => {
            match result {
                Ok(s) => s,
                Err(e) => {
                    // Tidy the terminal before surfacing the error.
                    spinner.finish_and_clear();
                    return Err(anyhow::anyhow!(e).context("Failed to create chat stream"));
                }
            }
        }
        _ = sigint.recv() => {
            info!("Ctrl-C received while waiting for API connection, interrupting");
            spinner.finish_and_clear();
            // Nothing was received yet; report a clean, empty interruption.
            return Ok(StreamResult {
                full_text: String::new(),
                tool_calls: vec![],
                interrupted: true,
            });
        }
    };
    debug!("Stream created successfully, starting to process chunks");
    let mut full_text = String::new();
    let mut tool_calls: Vec<ToolCallAccumulator> = Vec::new();
    // `started_text`: whether the assistant prefix has been printed (gates the
    // closing print and the "[interrupted]" marker).
    let mut started_text = false;
    // `spinner_cleared`: the spinner must be dismissed exactly once, on the
    // first visible output — tracked separately because tool-call chunks can
    // arrive before any text.
    let mut spinner_cleared = false;
    let mut chunk_count: u32 = 0;
    let mut interrupted = false;
    loop {
        // Race the next stream item against Ctrl-C each iteration. Both
        // `StreamExt::next` and `Signal::recv` are cancel-safe, so losing the
        // race never drops an already-received chunk or signal.
        tokio::select! {
            chunk = stream.next() => {
                let result = match chunk {
                    // `None` means the stream ended normally.
                    Some(r) => r,
                    None => break,
                };
                chunk_count += 1;
                let response = match result {
                    Ok(r) => r,
                    Err(e) => {
                        warn!(
                            error = %e,
                            chunks_received = chunk_count,
                            text_so_far_len = full_text.len(),
                            "Stream error occurred"
                        );
                        // Restore the terminal (clear spinner, close the
                        // output line) before bailing so the error message
                        // isn't mangled by leftover UI state.
                        if !spinner_cleared {
                            spinner.finish_and_clear();
                        }
                        if started_text {
                            jarvis_print_end();
                        }
                        anyhow::bail!("Stream error: {e}");
                    }
                };
                for choice in &response.choices {
                    let delta = &choice.delta;
                    if let Some(ref content) = delta.content {
                        debug!(
                            chunk = chunk_count,
                            content_length = content.len(),
                            has_content = true,
                            content = %content,
                            "Received text chunk"
                        );
                        // First visible text: drop the spinner and print the
                        // assistant prefix exactly once.
                        if !started_text {
                            if !spinner_cleared {
                                spinner.finish_and_clear();
                                spinner_cleared = true;
                            }
                            jarvis_print_prefix();
                            started_text = true;
                        }
                        jarvis_print_chunk(content);
                        // Flush so the user sees tokens as they stream in;
                        // a flush failure is deliberately ignored.
                        let _ = io::stdout().flush();
                        full_text.push_str(content);
                    }
                    if let Some(ref tc_chunks) = delta.tool_calls {
                        // Tool calls may arrive with no text at all; still
                        // dismiss the spinner on first activity.
                        if !spinner_cleared {
                            spinner.finish_and_clear();
                            spinner_cleared = true;
                        }
                        debug!(
                            chunk = chunk_count,
                            tool_call_chunks = tc_chunks.len(),
                            "Received tool call chunk"
                        );
                        // Merge each fragment into the accumulator list
                        // (presumably keyed by the fragment's index field —
                        // see `accumulate_tool_call`).
                        for chunk in tc_chunks {
                            accumulate_tool_call(&mut tool_calls, chunk);
                        }
                    }
                    // Chunks carrying only role/metadata (e.g. the opening
                    // delta) are logged for diagnostics but otherwise ignored.
                    if delta.content.is_none() && delta.tool_calls.is_none() {
                        debug!(
                            chunk = chunk_count,
                            role = ?delta.role,
                            "Received chunk with no content and no tool_calls"
                        );
                    }
                }
            }
            _ = sigint.recv() => {
                info!(
                    chunks_received = chunk_count,
                    text_so_far_len = full_text.len(),
                    "Ctrl-C received during AI streaming, interrupting"
                );
                // Keep partial results; flag the interruption for the caller.
                interrupted = true;
                break;
            }
        }
    }
    // Spinner may still be running if the stream ended/was interrupted before
    // any visible output.
    if !spinner_cleared {
        spinner.finish_and_clear();
    }
    if started_text {
        if interrupted {
            // Mark the truncation point in red so partial output is obvious.
            jarvis_print_chunk(&red(" [interrupted]"));
            let _ = io::stdout().flush();
        }
        jarvis_print_end();
    }
    debug!(
        total_chunks = chunk_count,
        full_text_length = full_text.len(),
        tool_calls_count = tool_calls.len(),
        started_text = started_text,
        is_first_round = is_first_round,
        interrupted = interrupted,
        "Stream processing completed"
    );
    Ok(StreamResult {
        full_text,
        tool_calls,
        interrupted,
    })
}