use crate::ui::llm::LlmProvider;
use crate::ui::prompt_builder;
use crate::ui::state::AppState;
use crate::ui::storage::ImageStore;
use crate::ui::tool_bridge::{TOOL_DEFS, dispatch};
use axum::Json;
use axum::extract::State;
use axum::http::StatusCode;
use axum::response::IntoResponse;
use base64::Engine;
use serde::Deserialize;
use serde_json::{Value, json};
use std::sync::Arc;
/// Upper bound on the number of LLM round-trips `tool_calling_loop` performs
/// before it gives up and reports that the round limit was exceeded.
const MAX_TOOL_ROUNDS: usize = 16;
/// Maximum number of image ids `inject_images` accepts for one request;
/// longer lists are rejected with `400 Bad Request`.
const MAX_IMAGES_PER_MESSAGE: usize = 4;
/// JSON request body accepted by [`post_chat`].
///
/// Every field except `messages` is optional in the incoming JSON.
#[derive(Debug, Deserialize)]
pub struct ChatBody {
    /// Conversation messages, forwarded as raw JSON values.
    pub messages: Vec<Value>,
    /// Rendered card JSON; when present it is included in the context message
    /// inserted before the last conversation message.
    #[serde(default)]
    pub current_card: Option<Value>,
    /// Preset name; when present it is included in the context message.
    #[serde(default)]
    pub current_preset: Option<String>,
    /// Card data; when present it is pretty-printed into the context message.
    #[serde(default)]
    pub current_data: Option<Value>,
    /// Selects the tool-calling path (`true`, the default when the field is
    /// omitted) or the legacy single-shot path (`false`).
    #[serde(default = "default_use_tools")]
    pub use_tools: bool,
    /// Ids of stored images to attach to the most recent user message.
    #[serde(default)]
    pub image_ids: Vec<String>,
}
/// Serde default for the `use_tools` field of `ChatBody`: a request that omits
/// the field gets `true`.
fn default_use_tools() -> bool {
    const TOOLS_ENABLED_BY_DEFAULT: bool = true;
    TOOLS_ENABLED_BY_DEFAULT
}
/// Attaches the stored images named by `image_ids` to the most recent user message.
///
/// Each image is loaded from `store`, base64-encoded, and wrapped in an
/// `image_url` content part carrying a `data:` URL. The last message whose
/// `role` is `"user"` then has its `content` replaced by an array holding the
/// image parts first, followed by the message's original content:
///
/// * a string `content` becomes a single `text` part;
/// * an array `content` (already multi-part) has its parts appended unchanged,
///   so previously supplied text is not lost;
/// * anything else (missing, null, ...) becomes an empty `text` part.
///
/// If `image_ids` is empty the messages are left untouched. If no message has
/// the `"user"` role, the images are loaded but nothing is attached.
///
/// # Errors
///
/// Returns a ready-made HTTP response:
/// * `400 Bad Request` when more than [`MAX_IMAGES_PER_MESSAGE`] ids are given;
/// * `404 Not Found` when `store.load` fails for any id (every load error is
///   reported as "not found").
async fn inject_images(
    messages: &mut [Value],
    image_ids: &[String],
    store: &crate::ui::storage::local::LocalImageStore,
) -> Result<(), axum::response::Response> {
    if image_ids.is_empty() {
        return Ok(());
    }
    if image_ids.len() > MAX_IMAGES_PER_MESSAGE {
        return Err((
            StatusCode::BAD_REQUEST,
            Json(json!({"error": format!(
                "too many images (max {MAX_IMAGES_PER_MESSAGE} per message)"
            )})),
        )
            .into_response());
    }

    // One slot per image plus one for the common single-text-part case.
    let mut image_parts: Vec<Value> = Vec::with_capacity(image_ids.len() + 1);
    for id in image_ids {
        let (_meta, data) = store.load(id).await.map_err(|_| {
            (
                StatusCode::NOT_FOUND,
                Json(json!({"error": format!("image not found: {id}")})),
            )
                .into_response()
        })?;
        let b64 = base64::engine::general_purpose::STANDARD.encode(&data);
        // NOTE(review): the loaded metadata is ignored and the data URL is always
        // labelled `image/png` -- confirm the store only ever holds PNG data.
        image_parts.push(json!({
            "type": "image_url",
            "image_url": { "url": format!("data:image/png;base64,{b64}") }
        }));
    }

    if let Some(last_user) = messages
        .iter_mut()
        .rev()
        .find(|m| m.get("role").and_then(Value::as_str) == Some("user"))
    {
        // Move the existing content out so it can be re-attached after the images.
        match last_user.get_mut("content").map(Value::take) {
            // Already multi-part: keep every existing part instead of dropping them.
            Some(Value::Array(existing_parts)) => image_parts.extend(existing_parts),
            Some(Value::String(text)) => {
                image_parts.push(json!({"type": "text", "text": text}));
            }
            _ => image_parts.push(json!({"type": "text", "text": ""})),
        }
        last_user["content"] = Value::Array(image_parts);
    }
    Ok(())
}
pub async fn post_chat(
State(state): State<Arc<AppState>>,
Json(body): Json<ChatBody>,
) -> impl IntoResponse {
let mut messages = build_messages_with_state(
&body.messages,
body.current_card.as_ref(),
body.current_preset.as_deref(),
body.current_data.as_ref(),
);
if let Err(resp) = inject_images(&mut messages, &body.image_ids, &state.image_store).await {
return resp;
}
let has_images = !body.image_ids.is_empty();
if !body.use_tools {
return legacy_chat(&state, &messages).await;
}
tool_calling_loop(&state, messages, has_images).await
}
async fn tool_calling_loop(
state: &Arc<AppState>,
initial_messages: Vec<Value>,
has_images: bool,
) -> axum::response::Response {
let user_query: String = initial_messages
.iter()
.rev()
.find(|m| m.get("role").and_then(Value::as_str) == Some("user"))
.and_then(|m| {
let c = m.get("content")?;
c.as_str().map(String::from).or_else(|| {
c.as_array()?.iter().find_map(|part| {
if part.get("type")?.as_str()? == "text" {
part.get("text")?.as_str().map(String::from)
} else {
None
}
})
})
})
.unwrap_or_default();
let system_prompt = if user_query.is_empty() {
state.system_prompt.clone()
} else {
prompt_builder::build_system_prompt_with_query(&state.knowledge_base, None, &user_query)
};
let system_prompt = if has_images {
format!("{system_prompt}\n\n{}", prompt_builder::VISION_INSTRUCTION)
} else {
system_prompt
};
let mut messages: Vec<Value> = Vec::with_capacity(initial_messages.len() + 1);
messages.push(json!({"role": "system", "content": system_prompt}));
messages.extend(initial_messages);
let mut tool_log: Vec<Value> = Vec::new();
for round in 0..MAX_TOOL_ROUNDS {
let response = match state
.llm
.chat_with_tools(messages.clone(), TOOL_DEFS.as_slice(), Some("auto"))
.await
{
Ok(r) => r,
Err(e) => return internal_err(e),
};
let Some(choice) = response.choices.first() else {
return internal_err(anyhow::anyhow!("OpenAI returned no choices"));
};
messages.push(serde_json::to_value(&choice.message).unwrap_or_else(|_| json!({})));
if !choice.message.tool_calls.is_empty() && choice.finish_reason == "tool_calls" {
for call in &choice.message.tool_calls {
let result_str = dispatch(
&state.knowledge_base,
&call.function.name,
&call.function.arguments,
)
.await;
tool_log.push(json!({
"round": round,
"tool": call.function.name,
"args": call.function.arguments,
"result": result_str,
}));
messages.push(json!({
"role": "tool",
"tool_call_id": call.id,
"content": result_str,
}));
}
continue;
}
return Json(json!({
"content": choice.message.content.clone().unwrap_or_default(),
"tool_log": tool_log,
}))
.into_response();
}
Json(json!({
"content": "Error: exceeded maximum tool-calling rounds",
"tool_log": tool_log,
}))
.into_response()
}
async fn legacy_chat(state: &Arc<AppState>, messages: &[Value]) -> axum::response::Response {
match crate::ui::llm::legacy_chat_json(&state.llm, &state.system_prompt, messages).await {
Ok(content) => Json(json!({ "content": content, "tool_log": [] })).into_response(),
Err(e) => internal_err(e),
}
}
fn internal_err(e: anyhow::Error) -> axum::response::Response {
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({"error": e.to_string()})),
)
.into_response()
}
/// Returns a copy of `messages` with a card-state context message inserted
/// just before the final message.
///
/// When none of `current_card`, `current_preset` or `current_data` is given,
/// the messages are copied unchanged. Otherwise a `"user"` message is built
/// that lists the preset name, the pretty-printed data and the pretty-printed
/// rendered card (each only if present), followed by an instruction to edit
/// the existing card or generate a new one. With an empty `messages` slice the
/// result contains only the context message.
fn build_messages_with_state(
    messages: &[Value],
    current_card: Option<&Value>,
    current_preset: Option<&str>,
    current_data: Option<&Value>,
) -> Vec<Value> {
    let nothing_to_add =
        current_card.is_none() && current_preset.is_none() && current_data.is_none();
    if nothing_to_add {
        return messages.to_vec();
    }

    let mut context = String::from("CURRENT CARD STATE (for MODIFY mode):\n");
    if let Some(preset) = current_preset {
        context.push_str("Preset: ");
        context.push_str(preset);
        context.push('\n');
    }
    // Data comes before the rendered card; each section is emitted only if present.
    let json_sections = [
        ("Data:\n", current_data),
        ("Rendered card JSON:\n", current_card),
    ];
    for (heading, section) in json_sections {
        let Some(value) = section else { continue };
        context.push_str(heading);
        context.push_str(&serde_json::to_string_pretty(value).unwrap_or_default());
        context.push('\n');
    }
    context.push_str(
        "\nIf the user's request is an EDIT of the above, modify the existing card. \
         Otherwise generate a new card.",
    );

    // Everything but the final message goes first, then the context, then the
    // final message, so the latest message stays last.
    let (earlier, latest) = messages.split_at(messages.len().saturating_sub(1));
    let mut out = Vec::with_capacity(messages.len() + 1);
    out.extend_from_slice(earlier);
    out.push(json!({ "role": "user", "content": context }));
    out.extend_from_slice(latest);
    out
}