clawgarden-agent 0.3.5

Agent runtime with persona/memory loader, judge, and pi RPC for ClawGarden
Documentation
//! LLM API bridge — calls Z.AI (GLM) directly via OpenAI-compatible chat completions API.
//!
//! Reads the API key from `ZAI_API_KEY` (or `Z_AI_API_KEY`) in the environment; calls return an error if no key is available.

use anyhow::{Context, Result};
use clawgarden_proto::MessagePayload;
use reqwest::Client;
use serde::{Deserialize, Serialize};
use std::time::Duration;
use tokio::time::timeout;

/// HTTP timeout, in milliseconds, applied to each LLM API request.
const LLM_TIMEOUT_MS: u64 = 8_000;

/// Root of the Z.AI OpenAI-compatible REST API; endpoint paths are appended to it.
const ZAI_API_BASE: &str = "https://api.z.ai/api/paas/v4";

/// Model name sent with every chat-completion request.
const DEFAULT_MODEL: &str = "glm-4.5-flash";

// ── OpenAI-compatible request/response types ──────────────────────────────

/// Outgoing chat message, serialized into the request's `messages` array.
#[derive(Debug, Serialize)]
struct SerChatMessage {
    /// Speaker role; this module sends `"system"`, `"assistant"`, or `"user"`.
    role: String,
    /// Message text.
    content: String,
}

/// JSON body POSTed to `{ZAI_API_BASE}/chat/completions`.
#[derive(Debug, Serialize)]
struct ChatRequest {
    /// Model identifier (this module always sends `DEFAULT_MODEL`).
    model: String,
    /// Conversation to complete; the system prompt comes first.
    messages: Vec<SerChatMessage>,
    /// Passed through as the API's `max_tokens` limit.
    max_tokens: u32,
    /// Passed through as the API's sampling `temperature`.
    temperature: f32,
}

/// Assistant message inside a chat-completion choice.
#[derive(Debug, Deserialize)]
struct ChatMessage {
    role: String,
    /// Answer text. OpenAI-compatible APIs may send `null` here (or omit the field),
    /// so it is optional instead of failing the whole response parse.
    #[serde(default)]
    content: Option<String>,
    /// Reasoning text, when the API returns any.
    #[serde(default)]
    reasoning_content: Option<String>,
}

impl ChatMessage {
    /// Returns the text to hand back to the caller.
    ///
    /// Prefers `content` (verbatim). If it is missing, empty, or whitespace-only,
    /// falls back to the trimmed `reasoning_content`. Returns an empty string when
    /// neither carries any text.
    fn get_response(&self) -> String {
        let content = self.content.as_deref().unwrap_or_default();
        if !content.trim().is_empty() {
            return content.to_string();
        }
        self.reasoning_content
            .as_deref()
            .map(str::trim)
            .unwrap_or_default()
            .to_string()
    }
}

/// The part of a chat-completions response body that this module reads.
#[derive(Debug, Deserialize)]
struct ChatResponse {
    /// Candidate replies; only the first one is used.
    choices: Vec<ChatChoice>,
}

/// One entry of `ChatResponse::choices`.
#[derive(Debug, Deserialize)]
struct ChatChoice {
    /// The generated message.
    message: ChatMessage,
}

// ── RPC request (kept for backward compat) ────────────────────────────────

/// Request for one agent reply; most fields feed the LLM prompt.
///
/// Still derives `Serialize`/`Deserialize`, presumably for backward compatibility
/// with an earlier RPC wire format — confirm before removing.
#[derive(Debug, Serialize, Deserialize)]
pub struct PiRpcRequest {
    /// Agent name; used in the fallback system prompt when persona and memory are empty.
    pub agent_name: String,
    /// Persona text placed at the start of the system prompt (skipped if empty).
    pub persona: String,
    /// Memory text appended to the system prompt under a `Memory:` heading (skipped if empty).
    pub memory: String,
    /// Not read when building the LLM request; presumably identifies the conversation.
    pub conversation_id: String,
    /// Not read when building the LLM request; presumably lets callers match replies to requests.
    pub correlation_id: String,
    /// The new message, sent as the final `user` turn.
    pub content: String,
    /// Earlier messages, each sent as an `assistant` turn before `content`.
    pub recent_messages: Vec<String>,
}

// ── Public API ────────────────────────────────────────────────────────────

/// Forwards a [`PiRpcRequest`] to the LLM backend and returns its reply.
///
/// Despite the historical "pi RPC" name, this goes straight to the HTTP
/// chat-completions API. No deadline is added beyond the HTTP client's own timeout.
pub async fn call_pi_rpc(request: PiRpcRequest) -> Result<MessagePayload> {
    let payload = call_llm_api(&request).await?;
    Ok(payload)
}

/// Call LLM API with timeout handling
pub async fn call_pi_rpc_safe(request: PiRpcRequest) -> Result<MessagePayload> {
    let result = timeout(
        Duration::from_millis(LLM_TIMEOUT_MS + 500),
        call_pi_rpc(request),
    )
    .await;

    match result {
        Ok(Ok(payload)) => Ok(payload),
        Ok(Err(e)) => {
            log::error!("LLM API call failed: {}", e);
            Err(e)
        }
        Err(_) => {
            log::error!("LLM API call timed out after {}ms", LLM_TIMEOUT_MS);
            Err(anyhow::anyhow!("LLM API timeout"))
        }
    }
}

// ── Implementation ────────────────────────────────────────────────────────

async fn call_llm_api(request: &PiRpcRequest) -> Result<MessagePayload> {
    let api_key = std::env::var("ZAI_API_KEY")
        .or_else(|_| std::env::var("Z_AI_API_KEY"))
        .context("ZAI_API_KEY not set in environment")?;

    // Build system prompt from persona + memory
    let mut system_content = String::new();
    if !request.persona.is_empty() {
        system_content.push_str(&request.persona);
        system_content.push('\n');
    }
    if !request.memory.is_empty() {
        system_content.push_str("\nMemory:\n");
        system_content.push_str(&request.memory);
        system_content.push('\n');
    }
    if system_content.is_empty() {
        system_content = format!(
            "You are {}, a helpful AI assistant in a group chat. Respond concisely in the same language as the user.",
            request.agent_name
        );
    }

    // Build messages
    let mut messages = vec![SerChatMessage {
        role: "system".to_string(),
        content: system_content,
    }];

    // Add recent messages as context
    for msg in &request.recent_messages {
        messages.push(SerChatMessage {
            role: "assistant".to_string(),
            content: msg.clone(),
        });
    }

    // Add the current user message
    messages.push(SerChatMessage {
        role: "user".to_string(),
        content: request.content.clone(),
    });

    let chat_request = ChatRequest {
        model: DEFAULT_MODEL.to_string(),
        messages,
        max_tokens: 1024,
        temperature: 0.7,
    };

    let client = Client::builder()
        .timeout(Duration::from_millis(LLM_TIMEOUT_MS))
        .build()?;

    let url = format!("{}/chat/completions", ZAI_API_BASE);

    let response = client
        .post(&url)
        .header("Authorization", format!("Bearer {}", api_key))
        .header("Content-Type", "application/json")
        .json(&chat_request)
        .send()
        .await
        .context("LLM API HTTP request failed")?;

    if !response.status().is_success() {
        let status = response.status();
        let body = response.text().await.unwrap_or_default();
        anyhow::bail!("LLM API error {}: {}", status, body);
    }

    let chat_response: ChatResponse = response
        .json()
        .await
        .context("Failed to parse LLM API response")?;

    let content = chat_response
        .choices
        .first()
        .map(|c| c.message.get_response())
        .filter(|s| !s.is_empty())
        .unwrap_or_else(|| "(no response)".to_string());

    Ok(MessagePayload {
        content,
        context: vec![],
    })
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for an outgoing chat message with the given role and text.
    fn chat_message(role: &str, content: &str) -> SerChatMessage {
        SerChatMessage {
            role: role.to_string(),
            content: content.to_string(),
        }
    }

    /// A `PiRpcRequest` keeps its key fields across a JSON round trip.
    #[test]
    fn test_rpc_request_serialization() {
        let original = PiRpcRequest {
            agent_name: "alex".into(),
            persona: "Role: pm".into(),
            memory: String::new(),
            conversation_id: "conv_1".into(),
            correlation_id: "req_1".into(),
            content: "Hello".into(),
            recent_messages: Vec::new(),
        };

        let encoded = serde_json::to_string(&original).expect("PiRpcRequest serializes");
        let decoded: PiRpcRequest =
            serde_json::from_str(&encoded).expect("PiRpcRequest deserializes");

        assert_eq!(decoded.agent_name, "alex");
        assert_eq!(decoded.content, "Hello");
    }

    /// The serialized `ChatRequest` carries the model name and the message text.
    #[test]
    fn test_chat_request_serialization() {
        let request = ChatRequest {
            model: "glm-4.5-flash".to_string(),
            messages: vec![
                chat_message("system", "You are helpful."),
                chat_message("user", "Hi"),
            ],
            max_tokens: 1024,
            temperature: 0.7,
        };

        let body = serde_json::to_string(&request).expect("ChatRequest serializes");
        for needle in ["glm-4.5-flash", "Hi"] {
            assert!(body.contains(needle), "serialized request is missing {:?}", needle);
        }
    }
}