use super::types::{StreamMsg, ToolResultMsg};
use crate::command::chat::agent::run_main_agent_loop;
use crate::command::chat::agent_config::{AgentLoopConfig, AgentLoopSharedState};
use crate::command::chat::error::ChatError;
use crate::command::chat::storage::ChatMessage;
use async_openai::types::chat::ChatCompletionTools;
use std::sync::{Arc, mpsc};
use tokio_util::sync::CancellationToken;
#[derive(Debug)]
pub struct MainAgentHandle {
pub stream_rx: mpsc::Receiver<StreamMsg>,
pub cancel_token: CancellationToken,
}
impl MainAgentHandle {
pub fn spawn(
config: AgentLoopConfig,
shared: AgentLoopSharedState,
api_messages: Vec<ChatMessage>,
tools: Vec<ChatCompletionTools>,
system_prompt_fn: Arc<dyn Fn() -> Option<String> + Send + Sync>,
) -> (Self, mpsc::SyncSender<ToolResultMsg>) {
let (stream_tx, stream_rx) = mpsc::channel::<StreamMsg>();
let (tool_result_tx, tool_result_rx) = mpsc::sync_channel::<ToolResultMsg>(16);
let cancel_token = config.cancel_token.clone();
std::thread::spawn(move || {
let stream_tx_panic = stream_tx.clone();
let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(move || {
let runtime = match tokio::runtime::Runtime::new() {
Ok(rt) => rt,
Err(e) => {
let _ = stream_tx
.send(StreamMsg::Error(ChatError::RuntimeFailed(e.to_string())));
return;
}
};
runtime.block_on(run_main_agent_loop(
config,
shared,
api_messages,
tools,
system_prompt_fn,
stream_tx,
tool_result_rx,
));
}));
if let Err(panic_info) = result {
let panic_msg = if let Some(s) = panic_info.downcast_ref::<&str>() {
format!("Agent 线程 panic: {}", s)
} else if let Some(s) = panic_info.downcast_ref::<String>() {
format!("Agent 线程 panic: {}", s)
} else {
"Agent 线程发生未知 panic".to_string()
};
crate::util::log::write_error_log("MainAgentHandle::spawn", &panic_msg);
let _ = stream_tx_panic.send(StreamMsg::Error(ChatError::AgentPanic(panic_msg)));
}
});
(
MainAgentHandle {
stream_rx,
cancel_token,
},
tool_result_tx,
)
}
pub fn cancel(&self) {
self.cancel_token.cancel();
}
pub fn poll(&self) -> Vec<StreamMsg> {
let mut msgs = Vec::new();
loop {
match self.stream_rx.try_recv() {
Ok(msg) => msgs.push(msg),
Err(mpsc::TryRecvError::Empty) => break,
Err(mpsc::TryRecvError::Disconnected) => {
msgs.push(StreamMsg::Error(ChatError::Other(
"Agent 通道已断开(agent 线程异常退出)".to_string(),
)));
break;
}
}
}
msgs
}
}