use crate::api::types::IncomingMessage;
use crate::llm::{self, ChatMessage, DEFAULT_MAX_TURNS, DEFAULT_SYSTEM_PROMPT};
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::Mutex;
use toq_core::config::HandlerEntry;
/// Per-thread conversation state, keyed by thread id.
///
/// Each value is `(history, turn_count)`: the chat messages exchanged so far
/// and the number of user turns recorded for that thread. A `turn_count` of
/// `usize::MAX` is used as a sentinel meaning "this thread has been closed";
/// messages arriving for such a thread are ignored.
type ThreadState = HashMap<String, (Vec<ChatMessage>, usize)>;
pub struct LlmHandler {
threads: Arc<Mutex<ThreadState>>,
api_url: String,
}
impl LlmHandler {
pub fn new(api_url: String) -> Self {
Self {
threads: Arc::new(Mutex::new(HashMap::new())),
api_url,
}
}
pub fn on_thread_close(&self, thread_id: &str) {
let threads = self.threads.clone();
let tid = thread_id.to_string();
tokio::spawn(async move {
threads.lock().await.remove(&tid);
});
}
pub fn dispatch(&self, handler: &HandlerEntry, msg: &IncomingMessage) {
let thread_id = msg.thread_id.clone().unwrap_or_else(|| msg.id.clone());
let from = msg.from.clone();
let text = msg
.body
.as_ref()
.and_then(|b| b.get("text"))
.and_then(|t| t.as_str())
.unwrap_or("")
.to_string();
let prompt = if let Some(ref file) = handler.prompt_file {
std::fs::read_to_string(file).unwrap_or_else(|_| DEFAULT_SYSTEM_PROMPT.to_string())
} else {
handler
.prompt
.clone()
.unwrap_or_else(|| DEFAULT_SYSTEM_PROMPT.to_string())
};
let max_turns = handler.max_turns.unwrap_or(if handler.auto_close {
usize::MAX
} else {
DEFAULT_MAX_TURNS
});
let auto_close = handler.auto_close;
let provider = handler.provider.clone();
let model = handler.model.clone();
let api_url = self.api_url.clone();
let handler_name = handler.name.clone();
let threads = self.threads.clone();
tokio::spawn(async move {
let (history_snapshot, turn_count, is_last_turn) = {
let mut guard = threads.lock().await;
let (history, count) = guard
.entry(thread_id.clone())
.or_insert_with(|| (Vec::new(), 0));
if *count == usize::MAX {
tracing::debug!("handler {handler_name}: thread {thread_id} already closed");
return;
}
history.push(ChatMessage {
role: "user".into(),
content: text,
});
*count += 1;
let is_last = *count >= max_turns;
(history.clone(), *count, is_last)
};
let include_close_tool = auto_close && !is_last_turn && turn_count >= 2;
let full_prompt = if max_turns < usize::MAX {
if is_last_turn {
format!(
"{prompt}\n\nThis is your final response in this conversation (turn {turn_count} of {max_turns}). Wrap up naturally."
)
} else if auto_close {
format!(
"{prompt}\n\nYou are on turn {turn_count} of {max_turns} in this conversation. When the conversation has been fully explored, use the close_thread tool to end it. Do not close prematurely."
)
} else {
format!(
"{prompt}\n\nYou are on turn {turn_count} of {max_turns} in this conversation."
)
}
} else if auto_close {
format!(
"{prompt}\n\nWhen the conversation has been fully explored and both sides have shared their thoughts, use the close_thread tool to end it. Do not close prematurely."
)
} else {
prompt.clone()
};
let result = llm::call(
&provider,
&model,
&full_prompt,
&history_snapshot,
include_close_tool,
)
.await;
match result {
Ok(resp) => {
let (safe_text, was_redacted) = llm::redact::redact(&resp.text);
if was_redacted {
tracing::warn!(
"handler {handler_name}: redacted credential pattern from LLM response"
);
}
let close = resp.close_thread || is_last_turn;
{
let mut guard = threads.lock().await;
if let Some((history, count)) = guard.get_mut(&thread_id) {
if !safe_text.is_empty() {
history.push(ChatMessage {
role: "assistant".into(),
content: safe_text.clone(),
});
}
if close {
*count = usize::MAX;
}
}
}
let mut body = serde_json::json!({
"to": from,
"body": {"text": safe_text},
"thread_id": thread_id,
});
if close {
body["close_thread"] = serde_json::json!(true);
}
let url = format!("{api_url}/v1/messages?wait=true");
match reqwest::Client::new().post(&url).json(&body).send().await {
Ok(r) if r.status().is_success() => {
tracing::info!(
"handler {handler_name}: replied to {from} (turn {turn_count}{})",
if close { ", closing" } else { "" }
);
}
Ok(r) => {
tracing::warn!(
"handler {handler_name}: reply failed: HTTP {}",
r.status()
);
}
Err(e) => {
tracing::warn!("handler {handler_name}: reply failed: {e}");
}
}
}
Err(e) => {
tracing::warn!("handler {handler_name}: LLM call failed: {e}");
}
}
});
}
}