use crate::ui::openai::{self, ChatRequest};
use crate::ui::prompt_builder;
use crate::ui::state::AppState;
use crate::ui::tool_bridge::{TOOL_DEFS, dispatch};
use axum::Json;
use axum::extract::State;
use axum::http::StatusCode;
use axum::response::IntoResponse;
use serde::Deserialize;
use serde_json::{Value, json};
use std::sync::Arc;
/// Upper bound on model round-trips in `tool_calling_loop`. If every round ends in tool
/// calls, the loop stops and reports "exceeded maximum tool-calling rounds" instead of
/// issuing another request.
const MAX_TOOL_ROUNDS: usize = 16;
/// JSON request body accepted by [`post_chat`].
#[derive(Debug, Deserialize)]
pub struct ChatBody {
    /// Conversation so far, as message objects. The handler reads their `"role"` and
    /// `"content"` keys when locating the latest user query.
    pub messages: Vec<Value>,
    /// Rendered card JSON of the card currently being edited, if any.
    #[serde(default)]
    pub current_card: Option<Value>,
    /// Preset of the card currently being edited, if any.
    #[serde(default)]
    pub current_preset: Option<String>,
    /// Data backing the card currently being edited, if any.
    #[serde(default)]
    pub current_data: Option<Value>,
    /// Whether to run the tool-calling loop. `false` falls back to a single plain
    /// completion. Defaults to `true` when the field is omitted.
    #[serde(default = "default_use_tools")]
    pub use_tools: bool,
}
/// Serde default for `ChatBody::use_tools`: tool calling stays enabled unless the client
/// explicitly opts out.
fn default_use_tools() -> bool {
    const ENABLED_BY_DEFAULT: bool = true;
    ENABLED_BY_DEFAULT
}
pub async fn post_chat(
State(state): State<Arc<AppState>>,
Json(body): Json<ChatBody>,
) -> impl IntoResponse {
let messages = build_messages_with_state(
&body.messages,
body.current_card.as_ref(),
body.current_preset.as_deref(),
body.current_data.as_ref(),
);
if !body.use_tools {
return legacy_chat(&state, &messages).await;
}
tool_calling_loop(&state, messages).await
}
/// Runs the OpenAI tool-calling loop for one chat request.
///
/// A system message is prepended to `initial_messages`. When the most recent `"user"`
/// message has string content, the prompt is built for that query with
/// `prompt_builder::build_system_prompt_with_query`. Otherwise the cached
/// `state.system_prompt` is used.
///
/// Each round sends the whole conversation plus `TOOL_DEFS` to the model. When the reply
/// carries tool calls and `finish_reason == "tool_calls"`:
/// - the assistant turn is appended to the history;
/// - every call is executed via `dispatch`;
/// - each result is appended as a `"tool"` message and recorded in the tool log;
/// - the next round starts.
///
/// Otherwise the assistant's text is returned.
///
/// # Responses
/// - `200` with `{"content": <answer>, "tool_log": [...]}` on a final answer.
/// - `200` with an error string in `content` if all `MAX_TOOL_ROUNDS` rounds ended in
///   tool calls.
/// - `500` (via `internal_err`) if:
///   - the OpenAI call fails;
///   - it returns no choices;
///   - the assistant turn cannot be serialized back into the history.
async fn tool_calling_loop(
    state: &Arc<AppState>,
    initial_messages: Vec<Value>,
) -> axum::response::Response {
    // NOTE(review): `build_messages_with_state` may insert a synthetic "user" context
    // message just before the final message. If the final message is not a user turn,
    // that context text is what gets picked up as the query here.
    let user_query = initial_messages
        .iter()
        .rev()
        .find(|m| m.get("role").and_then(Value::as_str) == Some("user"))
        .and_then(|m| m.get("content").and_then(Value::as_str))
        .unwrap_or("");
    let system_prompt = if user_query.is_empty() {
        state.system_prompt.clone()
    } else {
        prompt_builder::build_system_prompt_with_query(&state.knowledge_base, None, user_query)
    };
    let mut messages: Vec<Value> = Vec::with_capacity(initial_messages.len() + 1);
    messages.push(json!({"role": "system", "content": system_prompt}));
    messages.extend(initial_messages);
    let mut tool_log: Vec<Value> = Vec::new();
    for round in 0..MAX_TOOL_ROUNDS {
        let request = ChatRequest {
            model: &state.model,
            // `ChatRequest` owns its message list, so the history is cloned per round.
            messages: messages.clone(),
            tools: Some(TOOL_DEFS.as_slice()),
            tool_choice: Some("auto"),
            response_format: None,
        };
        let response = match openai::chat_with_tools(&state.openai_api_key, request).await {
            Ok(r) => r,
            Err(e) => return internal_err(e),
        };
        let Some(choice) = response.choices.first() else {
            return internal_err(anyhow::anyhow!("OpenAI returned no choices"));
        };
        let wants_tools =
            !choice.message.tool_calls.is_empty() && choice.finish_reason == "tool_calls";
        if !wants_tools {
            return Json(json!({
                "content": choice.message.content.clone().unwrap_or_default(),
                "tool_log": tool_log,
            }))
            .into_response();
        }
        // The assistant turn (with its `tool_calls`) must precede the "tool" replies that
        // reference its call ids. Substituting `{}` on failure would send a role-less
        // message in the next request and hide the real cause, so report the error instead.
        let assistant_message = match serde_json::to_value(&choice.message) {
            Ok(value) => value,
            Err(e) => {
                return internal_err(anyhow::anyhow!(
                    "failed to serialize assistant message: {}",
                    e
                ));
            }
        };
        messages.push(assistant_message);
        // Execute the requested tools in order, answering each call with a "tool" message.
        for call in &choice.message.tool_calls {
            let result_str = dispatch(
                &state.knowledge_base,
                &call.function.name,
                &call.function.arguments,
            )
            .await;
            tool_log.push(json!({
                "round": round,
                "tool": call.function.name,
                "args": call.function.arguments,
                "result": result_str,
            }));
            messages.push(json!({
                "role": "tool",
                "tool_call_id": call.id,
                "content": result_str,
            }));
        }
    }
    // NOTE(review): exhaustion is reported as a 200 with the error text in `content`,
    // not as an HTTP error status.
    Json(json!({
        "content": "Error: exceeded maximum tool-calling rounds",
        "tool_log": tool_log,
    }))
    .into_response()
}
/// Single-shot chat without tools.
///
/// Sends `messages` with the cached `state.system_prompt` through `openai::chat_json`.
/// The reply is wrapped in the same `{"content", "tool_log"}` shape as the tool-calling
/// path, with an empty tool log. Failures become a 500 via `internal_err`.
async fn legacy_chat(state: &Arc<AppState>, messages: &[Value]) -> axum::response::Response {
    let reply = openai::chat_json(
        &state.openai_api_key,
        &state.model,
        &state.system_prompt,
        messages,
    )
    .await;
    let content = match reply {
        Ok(text) => text,
        Err(e) => return internal_err(e),
    };
    Json(json!({ "content": content, "tool_log": [] })).into_response()
}
fn internal_err(e: anyhow::Error) -> axum::response::Response {
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({"error": e.to_string()})),
)
.into_response()
}
/// Injects the client's current card state into the conversation.
///
/// With no card, preset or data supplied, `messages` is returned unchanged (cloned).
/// Otherwise a `"user"` message is built from whichever pieces are present:
/// - the preset name;
/// - the data, pretty-printed;
/// - the rendered card JSON, pretty-printed.
///
/// It ends with an instruction to edit the card or generate a new one. The message is
/// inserted immediately before the last message, or becomes the only message when
/// `messages` is empty.
fn build_messages_with_state(
    messages: &[Value],
    current_card: Option<&Value>,
    current_preset: Option<&str>,
    current_data: Option<&Value>,
) -> Vec<Value> {
    let has_state = current_card.is_some() || current_preset.is_some() || current_data.is_some();
    if !has_state {
        return messages.to_vec();
    }
    let pretty = |value: &Value| serde_json::to_string_pretty(value).unwrap_or_default();
    let mut context = String::from("CURRENT CARD STATE (for MODIFY mode):\n");
    if let Some(preset) = current_preset {
        context += "Preset: ";
        context += preset;
        context.push('\n');
    }
    if let Some(data) = current_data {
        context += "Data:\n";
        context += &pretty(data);
        context.push('\n');
    }
    if let Some(card) = current_card {
        context += "Rendered card JSON:\n";
        context += &pretty(card);
        context.push('\n');
    }
    context += "\nIf the user's request is an EDIT of the above, modify the existing card. \
         Otherwise generate a new card.";
    // Everything but the final message, then the context, then the final message (if any).
    let (history, latest) = messages.split_at(messages.len().saturating_sub(1));
    history
        .iter()
        .cloned()
        .chain(std::iter::once(json!({ "role": "user", "content": context })))
        .chain(latest.iter().cloned())
        .collect()
}