
agent_runtime/llm/provider/llama.rs

use async_trait::async_trait;
use futures::stream::StreamExt;
use reqwest::Client as HttpClient;
use serde::{Deserialize, Serialize};

use super::super::{ChatClient, ChatRequest, ChatResponse, LlmError, LlmResult, TextStream};

/// Llama.cpp server client (local or remote).
///
/// Talks to llama.cpp's OpenAI-compatible API server, which typically listens
/// on `localhost:8080`.
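///
/// # Example
///
/// Illustrative sketch only: it assumes a llama.cpp server with the
/// OpenAI-compatible endpoints is already running, and the `use` path below is
/// inferred from this file's location rather than taken from the crate's
/// public API (hence the example is not compiled as a doc test).
///
/// ```ignore
/// use agent_runtime::llm::provider::llama::LlamaClient;
///
/// // Default local server on http://localhost:8080
/// let client = LlamaClient::localhost();
///
/// // Or point at an explicit server; llama.cpp usually ignores the model name
/// let remote = LlamaClient::new("http://192.168.1.50:8080", "llama");
/// ```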
pub struct LlamaClient {
    base_url: String,
    model: String,
    http_client: HttpClient,
}

impl LlamaClient {
    /// Create a new llama.cpp client
    ///
    /// # Arguments
    /// * `base_url` - Base URL of the llama.cpp server (e.g., "http://localhost:8080")
    /// * `model` - Model name (llama.cpp usually ignores this)
    pub fn new(base_url: impl Into<String>, model: impl Into<String>) -> Self {
        Self {
            base_url: base_url.into(),
            model: model.into(),
            http_client: HttpClient::new(),
        }
    }

    /// Create a new llama.cpp client with a custom HTTP client.
    /// Useful for configuring TLS, timeouts, etc.
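    ///
    /// # Example
    ///
    /// Sketch only: assumes `reqwest`'s standard builder API and, as above, an
    /// inferred `use` path (so it is not compiled as a doc test).
    ///
    /// ```ignore
    /// use std::time::Duration;
    /// use agent_runtime::llm::provider::llama::LlamaClient;
    ///
    /// // HTTP client with a generous timeout for slow local generation
    /// let http_client = reqwest::Client::builder()
    ///     .timeout(Duration::from_secs(300))
    ///     .build()
    ///     .expect("failed to build HTTP client");
    ///
    /// let client = LlamaClient::with_http_client("http://localhost:8080", "llama", http_client);
    /// ```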
    pub fn with_http_client(
        base_url: impl Into<String>,
        model: impl Into<String>,
        http_client: HttpClient,
    ) -> Self {
        Self {
            base_url: base_url.into(),
            model: model.into(),
            http_client,
        }
    }

    /// Create a client pointing to localhost:8080 (default llama.cpp port)
    pub fn localhost() -> Self {
        Self::new("http://localhost:8080", "llama")
    }

    /// Create a client pointing to localhost with custom port
    pub fn localhost_with_port(port: u16) -> Self {
        Self::new(format!("http://localhost:{}", port), "llama")
    }

    /// Create a client with insecure HTTPS (accepts self-signed certificates)
    /// Useful for local development with HTTPS servers
    pub fn insecure(base_url: impl Into<String>, model: impl Into<String>) -> Self {
        let http_client = HttpClient::builder()
            .danger_accept_invalid_certs(true)
            .build()
            .expect("Failed to build HTTP client");

        Self::with_http_client(base_url, model, http_client)
    }

    /// Create localhost client with insecure HTTPS on custom port
    pub fn localhost_insecure(port: u16) -> Self {
        Self::insecure(format!("https://localhost:{}", port), "llama")
    }
}

#[async_trait]
impl ChatClient for LlamaClient {
    async fn chat(&self, request: ChatRequest) -> LlmResult<ChatResponse> {
        let url = format!("{}/v1/chat/completions", self.base_url);

        // Build llama.cpp-compatible request
        let llama_request = LlamaChatRequest {
            model: self.model.clone(),
            messages: request.messages,
            temperature: request.temperature,
            max_tokens: request.max_tokens,
            top_p: request.top_p,
        };

        // Send request
        let response = self
            .http_client
            .post(&url)
            .header("Content-Type", "application/json")
            .json(&llama_request)
            .send()
            .await
            .map_err(|e| LlmError::NetworkError(e.to_string()))?;

        // Check status
        let status = response.status();
        if !status.is_success() {
            let error_text = response.text().await.unwrap_or_default();
            return Err(LlmError::ApiError(format!(
                "Status {}: {}",
                status, error_text
            )));
        }

        // Parse response (same format as OpenAI)
        let llama_response: LlamaChatResponse = response
            .json()
            .await
            .map_err(|e| LlmError::ParseError(e.to_string()))?;

        // Extract first choice
        let choice = llama_response
            .choices
            .first()
            .ok_or_else(|| LlmError::ParseError("No choices in response".to_string()))?;

        Ok(ChatResponse {
            content: choice.message.content.clone(),
            model: llama_response.model.unwrap_or_else(|| self.model.clone()),
            usage: llama_response.usage.map(|u| super::super::types::Usage {
                prompt_tokens: u.prompt_tokens,
                completion_tokens: u.completion_tokens,
                total_tokens: u.total_tokens,
            }),
            finish_reason: choice.finish_reason.clone(),
        })
    }

    async fn chat_stream(&self, request: ChatRequest) -> LlmResult<TextStream> {
        let url = format!("{}/v1/chat/completions", self.base_url);

        // Build llama.cpp-compatible request with streaming enabled
        let llama_request = LlamaChatRequest {
            model: self.model.clone(),
            messages: request.messages,
            temperature: request.temperature,
            max_tokens: request.max_tokens,
            top_p: request.top_p,
        };

        // Reuse the non-streaming request body and enable streaming on top of
        // it, so optional fields are skipped here exactly as in `chat` instead
        // of being serialized as null.
        let mut body = serde_json::to_value(&llama_request)
            .map_err(|e| LlmError::ParseError(e.to_string()))?;
        body["stream"] = serde_json::Value::Bool(true);

        // Send request with streaming
        let response = self
            .http_client
            .post(&url)
            .header("Content-Type", "application/json")
            .header("Accept", "text/event-stream")
            .json(&body)
            .send()
            .await
            .map_err(|e| LlmError::NetworkError(e.to_string()))?;

        if !response.status().is_success() {
            let status = response.status();
            let error_text = response.text().await.unwrap_or_default();
            return Err(LlmError::ApiError(format!(
                "HTTP {}: {}",
                status, error_text
            )));
        }

        // Convert byte stream to text chunks
        let stream = response.bytes_stream();
        let text_stream = stream.map(|chunk_result| {
            chunk_result
                .map_err(|e| LlmError::NetworkError(e.to_string()))
                .map(|bytes| {
                    // Parse SSE format: "data: {...}\n\n". This assumes each
                    // event arrives whole within one chunk; events split across
                    // chunk boundaries are not reassembled.
                    let text = String::from_utf8_lossy(&bytes);
                    let mut content = String::new();
                    for line in text.lines() {
                        if let Some(json_str) = line.strip_prefix("data: ") {
                            if json_str.trim() == "[DONE]" {
                                continue;
                            }
                            if let Ok(parsed) = serde_json::from_str::<serde_json::Value>(json_str)
                            {
                                if let Some(delta) = parsed
                                    .get("choices")
                                    .and_then(|c| c.get(0))
                                    .and_then(|c| c.get("delta"))
                                    .and_then(|d| d.get("content"))
                                    .and_then(|c| c.as_str())
                                {
                                    // Accumulate every delta in this chunk
                                    // rather than returning only the first one.
                                    content.push_str(delta);
                                }
                            }
                        }
                    }
                    content
                })
        });

        Ok(Box::pin(text_stream))
    }

    fn model(&self) -> &str {
        &self.model
    }

    fn provider(&self) -> &str {
        "llama.cpp"
    }
}
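
// Usage sketch (illustrative, not part of the file's API): how a caller might
// drive both paths. `request` and `another_request` stand for `ChatRequest`
// values built elsewhere in the crate, and `TextStream` is assumed to yield
// `LlmResult<String>` chunks as produced by `chat_stream` above.
//
//     // Non-streaming: one complete response
//     let response = client.chat(request).await?;
//     println!("{}", response.content);
//
//     // Streaming: print deltas as they arrive (uses the `StreamExt` import
//     // at the top of this file)
//     let mut stream = client.chat_stream(another_request).await?;
//     while let Some(chunk) = stream.next().await {
//         print!("{}", chunk?);
//     }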

// llama.cpp request/response types (OpenAI-compatible)

#[derive(Debug, Serialize)]
struct LlamaChatRequest {
    model: String,
    messages: Vec<super::super::types::ChatMessage>,

    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f32>,

    #[serde(skip_serializing_if = "Option::is_none")]
    max_tokens: Option<u32>,

    #[serde(skip_serializing_if = "Option::is_none")]
    top_p: Option<f32>,
}

#[derive(Debug, Deserialize)]
struct LlamaChatResponse {
    model: Option<String>,
    choices: Vec<Choice>,
    usage: Option<UsageInfo>,
}

#[derive(Debug, Deserialize)]
struct Choice {
    message: Message,
    finish_reason: Option<String>,
}

#[derive(Debug, Deserialize)]
struct Message {
    content: String,
}

#[derive(Debug, Deserialize)]
struct UsageInfo {
    prompt_tokens: u32,
    completion_tokens: u32,
    total_tokens: u32,
}
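
// Illustrative payloads these types are written against (hand-written to match
// the definitions above, not captured from a real server):
//
// Non-streaming response parsed into `LlamaChatResponse`:
//     {
//       "model": "llama",
//       "choices": [{"message": {"content": "Hello!"}, "finish_reason": "stop"}],
//       "usage": {"prompt_tokens": 12, "completion_tokens": 4, "total_tokens": 16}
//     }
//
// Streaming: each SSE "data:" line handled in `chat_stream` looks like
//     data: {"choices":[{"delta":{"content":"Hel"}}]}
// and the stream is terminated by
//     data: [DONE]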