agent_runtime/llm/provider/llama.rs

use async_trait::async_trait;
use futures::stream::StreamExt;
use reqwest::Client as HttpClient;
use serde::{Deserialize, Serialize};
use tokio::sync::mpsc;

use super::super::{ChatClient, ChatRequest, ChatResponse, LlmError, LlmResult};

/// Llama.cpp server client (local or remote)
///
/// Compatible with llama.cpp's OpenAI-compatible API server
/// Typically runs on localhost:8080 or similar
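///
/// # Example
///
/// A construction-only sketch, marked `ignore` because the doctest import path is
/// assumed from this file's location and may differ from the crate's re-exports:
///
/// ```ignore
/// use agent_runtime::llm::provider::llama::LlamaClient;
///
/// // Default local server on http://localhost:8080
/// let local = LlamaClient::localhost();
/// // Same, but on a custom port
/// let other = LlamaClient::localhost_with_port(8081);
/// ```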
pub struct LlamaClient {
    base_url: String,
    model: String,
    http_client: HttpClient,
}

impl LlamaClient {
    /// Create a new llama.cpp client
    ///
    /// # Arguments
    /// * `base_url` - Base URL of llama.cpp server (e.g., "http://localhost:8080")
    /// * `model` - Model name (llama.cpp servers typically ignore this and use whichever model is loaded)
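    ///
    /// # Example
    ///
    /// A minimal construction sketch, marked `ignore` because the doctest path
    /// depends on the crate's module layout:
    ///
    /// ```ignore
    /// let client = LlamaClient::new("http://192.168.1.10:8080", "llama");
    /// assert_eq!(client.provider(), "llama.cpp");
    /// ```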
    pub fn new(base_url: impl Into<String>, model: impl Into<String>) -> Self {
        Self {
            base_url: base_url.into(),
            model: model.into(),
            http_client: HttpClient::new(),
        }
    }

    /// Create a new llama.cpp client with a custom HTTP client
    /// Useful for configuring TLS, timeouts, etc.
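    ///
    /// A sketch using a `reqwest` client with a request timeout (the builder call is
    /// standard `reqwest`; the rest is illustrative and marked `ignore`):
    ///
    /// ```ignore
    /// use std::time::Duration;
    ///
    /// let http_client = reqwest::Client::builder()
    ///     .timeout(Duration::from_secs(120))
    ///     .build()
    ///     .expect("failed to build HTTP client");
    /// let client = LlamaClient::with_http_client("http://localhost:8080", "llama", http_client);
    /// ```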
    pub fn with_http_client(
        base_url: impl Into<String>,
        model: impl Into<String>,
        http_client: HttpClient,
    ) -> Self {
        Self {
            base_url: base_url.into(),
            model: model.into(),
            http_client,
        }
    }

    /// Create a client pointing to localhost:8080 (default llama.cpp port)
    pub fn localhost() -> Self {
        Self::new("http://localhost:8080", "llama")
    }

    /// Create a client pointing to localhost with custom port
    pub fn localhost_with_port(port: u16) -> Self {
        Self::new(format!("http://localhost:{}", port), "llama")
    }

    /// Create a client with insecure HTTPS (accepts self-signed certificates)
    /// Useful for local development with HTTPS servers
    pub fn insecure(base_url: impl Into<String>, model: impl Into<String>) -> Self {
        let http_client = HttpClient::builder()
            .danger_accept_invalid_certs(true)
            .build()
            .expect("Failed to build HTTP client");

        Self::with_http_client(base_url, model, http_client)
    }

    /// Create localhost client with insecure HTTPS on custom port
    pub fn localhost_insecure(port: u16) -> Self {
        Self::insecure(format!("https://localhost:{}", port), "llama")
    }

    /// Get the model name
    pub fn model(&self) -> &str {
        &self.model
    }

    /// Get the provider name
    pub fn provider(&self) -> &str {
        "llama.cpp"
    }
}

#[async_trait]
impl ChatClient for LlamaClient {
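    /// Send a single non-streaming chat completion request.
    ///
    /// A usage sketch, marked `ignore` because constructing a `ChatRequest` depends
    /// on the parent module's types, which are not shown in this file
    /// (`build_request()` below is a hypothetical helper):
    ///
    /// ```ignore
    /// let client = LlamaClient::localhost();
    /// let request = build_request();
    /// let response = client.chat(request).await?;
    /// println!("{}", response.content);
    /// ```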
    async fn chat(&self, request: ChatRequest) -> LlmResult<ChatResponse> {
        // Use the /v1 path so both chat and chat_stream hit the same
        // OpenAI-compatible endpoint exposed by llama.cpp's server.
        let url = format!("{}/v1/chat/completions", self.base_url);

        // Build llama.cpp-compatible request
        let llama_request = LlamaChatRequest {
            model: self.model.clone(),
            messages: request.messages,
            temperature: request.temperature,
            max_tokens: request.max_tokens,
            top_p: request.top_p,
            tools: request.tools,
        };

        // Send request
        let response = self
            .http_client
            .post(&url)
            .header("Content-Type", "application/json")
            .json(&llama_request)
            .send()
            .await
            .map_err(|e| LlmError::NetworkError(e.to_string()))?;

        // Check status
        let status = response.status();
        if !status.is_success() {
            let error_text = response.text().await.unwrap_or_default();
            return Err(LlmError::ApiError(format!(
                "Status {}: {}",
                status, error_text
            )));
        }

        // Parse response (same format as OpenAI)
        let llama_response: LlamaChatResponse = response
            .json()
            .await
            .map_err(|e| LlmError::ParseError(e.to_string()))?;

        // Extract first choice
        let choice = llama_response
            .choices
            .first()
            .ok_or_else(|| LlmError::ParseError("No choices in response".to_string()))?;

        // Convert llama.cpp tool_calls to our ToolCall type
        let tool_calls = choice.message.tool_calls.as_ref().map(|calls| {
            calls
                .iter()
                .map(|tc| super::super::types::ToolCall {
                    id: tc.id.clone(),
                    r#type: tc.r#type.clone(),
                    function: super::super::types::FunctionCall {
                        name: tc.function.name.clone(),
                        arguments: tc.function.arguments.clone(),
                    },
                })
                .collect()
        });

        Ok(ChatResponse {
            content: choice.message.content.clone(),
            model: llama_response.model.unwrap_or_else(|| self.model.clone()),
            usage: llama_response.usage.map(|u| super::super::types::Usage {
                prompt_tokens: u.prompt_tokens,
                completion_tokens: u.completion_tokens,
                total_tokens: u.total_tokens,
            }),
            finish_reason: choice.finish_reason.clone(),
            tool_calls,
        })
    }

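    /// Stream content deltas through `tx`, then issue a regular `chat` call so the
    /// returned `ChatResponse` also carries tool calls and usage.
    ///
    /// A consumption sketch, marked `ignore` because `ChatRequest` construction
    /// depends on the parent module's API (`build_request()` is hypothetical):
    ///
    /// ```ignore
    /// let (tx, mut rx) = tokio::sync::mpsc::channel::<String>(64);
    /// let printer = tokio::spawn(async move {
    ///     while let Some(chunk) = rx.recv().await {
    ///         print!("{chunk}");
    ///     }
    /// });
    /// let client = LlamaClient::localhost();
    /// let response = client.chat_stream(build_request(), tx).await?;
    /// printer.await.ok();
    /// ```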
    async fn chat_stream(
        &self,
        request: ChatRequest,
        tx: mpsc::Sender<String>,
    ) -> LlmResult<ChatResponse> {
        let url = format!("{}/v1/chat/completions", self.base_url);

        // Build llama.cpp-compatible request with streaming enabled
        let llama_request = LlamaChatRequest {
            model: self.model.clone(),
            messages: request.messages.clone(),
            temperature: request.temperature,
            max_tokens: request.max_tokens,
            top_p: request.top_p,
            tools: request.tools.clone(),
        };

        // Serialize the typed request and flip on streaming. Reusing the struct's
        // serialization (instead of rebuilding the body with json!) keeps the
        // skip_serializing_if behavior, so unset optional fields are omitted
        // rather than sent as null.
        let mut body = serde_json::to_value(&llama_request)
            .map_err(|e| LlmError::ParseError(e.to_string()))?;
        body["stream"] = serde_json::Value::Bool(true);

        // Send request with streaming
        let response = self
            .http_client
            .post(&url)
            .header("Content-Type", "application/json")
            .header("Accept", "text/event-stream")
            .json(&body)
            .send()
            .await
            .map_err(|e| LlmError::NetworkError(e.to_string()))?;

        if !response.status().is_success() {
            let status = response.status();
            let error_text = response.text().await.unwrap_or_default();
            return Err(LlmError::ApiError(format!(
                "HTTP {}: {}",
                status, error_text
            )));
        }

        // Convert the byte stream into SSE lines and forward content deltas through
        // the channel. A single SSE event can be split across chunks, so buffer
        // partial lines until the rest arrives.
        let mut stream = response.bytes_stream();
        let mut buffer = String::new();
        while let Some(chunk_result) = stream.next().await {
            let bytes = chunk_result.map_err(|e| LlmError::NetworkError(e.to_string()))?;
            buffer.push_str(&String::from_utf8_lossy(&bytes));

            // Parse SSE format: "data: {...}\n\n", processing only complete lines
            // and keeping any trailing partial line in the buffer.
            while let Some(newline_pos) = buffer.find('\n') {
                let line: String = buffer.drain(..=newline_pos).collect();
                let line = line.trim_end();
                if let Some(json_str) = line.strip_prefix("data: ") {
                    if json_str.trim() == "[DONE]" {
                        continue;
                    }
                    if let Ok(parsed) = serde_json::from_str::<serde_json::Value>(json_str) {
                        if let Some(delta) = parsed
                            .get("choices")
                            .and_then(|c| c.get(0))
                            .and_then(|c| c.get("delta"))
                            .and_then(|d| d.get("content"))
                            .and_then(|c| c.as_str())
                        {
                            let _ = tx.send(delta.to_string()).await;
                        }
                    }
                }
            }
        }

        // After streaming, make a regular call to get the full response with tool_calls
        self.chat(request).await
    }
}

// llama.cpp request/response types (OpenAI-compatible)

#[derive(Debug, Serialize)]
struct LlamaChatRequest {
    model: String,
    messages: Vec<super::super::types::ChatMessage>,

    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f32>,

    #[serde(skip_serializing_if = "Option::is_none")]
    max_tokens: Option<u32>,

    #[serde(skip_serializing_if = "Option::is_none")]
    top_p: Option<f32>,

    #[serde(skip_serializing_if = "Option::is_none")]
    tools: Option<Vec<serde_json::Value>>,
}

#[derive(Debug, Deserialize)]
struct LlamaChatResponse {
    model: Option<String>,
    choices: Vec<Choice>,
    usage: Option<UsageInfo>,
}

#[derive(Debug, Deserialize)]
struct Choice {
    message: Message,
    finish_reason: Option<String>,
}

#[derive(Debug, Deserialize)]
struct Message {
    #[serde(default)]
    content: String,

    // This struct is only deserialized, so use `default` (not skip_serializing_if)
    // for the optional tool_calls field.
    #[serde(default)]
    tool_calls: Option<Vec<LlamaToolCall>>,
}

#[derive(Debug, Deserialize)]
struct LlamaToolCall {
    id: String,
    r#type: String,
    function: LlamaFunction,
}

#[derive(Debug, Deserialize)]
struct LlamaFunction {
    name: String,
    arguments: String,
}

#[derive(Debug, Deserialize)]
struct UsageInfo {
    prompt_tokens: u32,
    completion_tokens: u32,
    total_tokens: u32,
}
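
// A small unit-test sketch exercising only the constructors defined above; it needs
// no running server and assumes nothing beyond this file's own API.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn localhost_client_reports_expected_model_and_provider() {
        let client = LlamaClient::localhost();
        assert_eq!(client.model(), "llama");
        assert_eq!(client.provider(), "llama.cpp");
    }
}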