use anyhow::{Context, Result};
use clawgarden_proto::MessagePayload;
use reqwest::Client;
use serde::{Deserialize, Serialize};
use std::time::Duration;
use tokio::time::timeout;
/// Timeout, in milliseconds, applied to the HTTP client built in `call_llm_api`.
/// `call_pi_rpc_safe` wraps the whole call in an outer timeout 500 ms longer.
const LLM_TIMEOUT_MS: u64 = 8000;
/// Base URL of the LLM API; `call_llm_api` appends `/chat/completions` to it.
const ZAI_API_BASE: &str = "https://api.z.ai/api/paas/v4";
/// Model name placed in the `model` field of every `ChatRequest` sent by `call_llm_api`.
const DEFAULT_MODEL: &str = "glm-4-flash";
/// JSON body that `call_llm_api` POSTs to the chat-completions endpoint.
#[derive(Debug, Serialize)]
struct ChatRequest {
/// Model identifier; `call_llm_api` always passes `DEFAULT_MODEL`.
model: String,
/// Chat turns, in the order they are sent.
messages: Vec<ChatMessage>,
/// Set to 256 by `call_llm_api`; presumably caps the reply length — confirm against the API docs.
max_tokens: u32,
/// Set to 0.7 by `call_llm_api`; presumably the sampling temperature — confirm against the API docs.
temperature: f32,
}
/// One chat turn. Serialized inside `ChatRequest` and deserialized inside `ChatChoice`.
#[derive(Debug, Serialize, Deserialize)]
struct ChatMessage {
/// Speaker role; the values written by this file are "system", "assistant" and "user".
role: String,
/// Text of the turn.
content: String,
}
/// The part of the chat-completions response body that this file reads.
/// Only the `choices` array is declared here.
#[derive(Debug, Deserialize)]
struct ChatResponse {
/// Candidate replies; `call_llm_api` uses only the first one.
choices: Vec<ChatChoice>,
}
/// One entry of `ChatResponse::choices`; only its `message` is declared here.
#[derive(Debug, Deserialize)]
struct ChatChoice {
/// The reply turn; `call_llm_api` returns its `content`.
message: ChatMessage,
}
/// Input for `call_pi_rpc` / `call_pi_rpc_safe`: everything needed to ask the
/// LLM for one reply on behalf of an agent.
#[derive(Debug, Serialize, Deserialize)]
pub struct PiRpcRequest {
/// Agent name; only used in the fallback system prompt when both `persona`
/// and `memory` are empty.
pub agent_name: String,
/// Persona text placed first in the system prompt; skipped when empty.
pub persona: String,
/// Memory text appended to the system prompt under a "Memory:" heading;
/// skipped when empty.
pub memory: String,
/// Not read by `call_llm_api` in this file.
pub conversation_id: String,
/// Not read by `call_llm_api` in this file.
pub correlation_id: String,
/// Text sent as the final "user" message.
pub content: String,
/// Earlier messages; each one is sent, in order, with the "assistant" role.
pub recent_messages: Vec<String>,
}
/// Forwards `request` to `call_llm_api` and returns its result unchanged.
///
/// This function adds no timeout of its own and does no logging;
/// `call_pi_rpc_safe` is the variant that adds both.
pub async fn call_pi_rpc(request: PiRpcRequest) -> Result<MessagePayload> {
call_llm_api(&request).await
}
pub async fn call_pi_rpc_safe(request: PiRpcRequest) -> Result<MessagePayload> {
let result = timeout(
Duration::from_millis(LLM_TIMEOUT_MS + 500),
call_pi_rpc(request),
)
.await;
match result {
Ok(Ok(payload)) => Ok(payload),
Ok(Err(e)) => {
log::error!("LLM API call failed: {}", e);
Err(e)
}
Err(_) => {
log::error!("LLM API call timed out after {}ms", LLM_TIMEOUT_MS);
Err(anyhow::anyhow!("LLM API timeout"))
}
}
}
async fn call_llm_api(request: &PiRpcRequest) -> Result<MessagePayload> {
let api_key = std::env::var("ZAI_API_KEY")
.or_else(|_| std::env::var("Z_AI_API_KEY"))
.context("ZAI_API_KEY not set in environment")?;
let mut system_content = String::new();
if !request.persona.is_empty() {
system_content.push_str(&request.persona);
system_content.push('\n');
}
if !request.memory.is_empty() {
system_content.push_str("\nMemory:\n");
system_content.push_str(&request.memory);
system_content.push('\n');
}
if system_content.is_empty() {
system_content = format!(
"You are {}, a helpful AI assistant in a group chat. Respond concisely in the same language as the user.",
request.agent_name
);
}
let mut messages = vec![ChatMessage {
role: "system".to_string(),
content: system_content,
}];
for msg in &request.recent_messages {
messages.push(ChatMessage {
role: "assistant".to_string(),
content: msg.clone(),
});
}
messages.push(ChatMessage {
role: "user".to_string(),
content: request.content.clone(),
});
let chat_request = ChatRequest {
model: DEFAULT_MODEL.to_string(),
messages,
max_tokens: 256,
temperature: 0.7,
};
let client = Client::builder()
.timeout(Duration::from_millis(LLM_TIMEOUT_MS))
.build()?;
let url = format!("{}/chat/completions", ZAI_API_BASE);
let response = client
.post(&url)
.header("Authorization", format!("Bearer {}", api_key))
.header("Content-Type", "application/json")
.json(&chat_request)
.send()
.await
.context("LLM API HTTP request failed")?;
if !response.status().is_success() {
let status = response.status();
let body = response.text().await.unwrap_or_default();
anyhow::bail!("LLM API error {}: {}", status, body);
}
let chat_response: ChatResponse = response
.json()
.await
.context("Failed to parse LLM API response")?;
let content = chat_response
.choices
.first()
.map(|c| c.message.content.clone())
.unwrap_or_else(|| "(no response)".to_string());
Ok(MessagePayload {
content,
context: vec![],
})
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A `PiRpcRequest` keeps its `agent_name` and `content` across a JSON round trip.
    #[test]
    fn test_rpc_request_serialization() {
        let original = PiRpcRequest {
            agent_name: String::from("alex"),
            persona: String::from("Role: pm"),
            memory: String::new(),
            conversation_id: String::from("conv_1"),
            correlation_id: String::from("req_1"),
            content: String::from("Hello"),
            recent_messages: Vec::new(),
        };

        let encoded = serde_json::to_string(&original).unwrap();
        let decoded: PiRpcRequest = serde_json::from_str(&encoded).unwrap();

        assert_eq!(decoded.agent_name, "alex");
        assert_eq!(decoded.content, "Hello");
    }

    /// A serialized `ChatRequest` contains the model name and the message text.
    #[test]
    fn test_chat_request_serialization() {
        let system_turn = ChatMessage {
            role: String::from("system"),
            content: String::from("You are helpful."),
        };
        let user_turn = ChatMessage {
            role: String::from("user"),
            content: String::from("Hi"),
        };
        let chat_request = ChatRequest {
            model: String::from("glm-4-flash"),
            messages: vec![system_turn, user_turn],
            max_tokens: 256,
            temperature: 0.7,
        };

        let encoded = serde_json::to_string(&chat_request).unwrap();

        assert!(encoded.contains("glm-4-flash"));
        assert!(encoded.contains("Hi"));
    }
}