agent_runtime/llm/provider/llama.rs

use async_trait::async_trait;
use futures::stream::StreamExt;
use reqwest::Client as HttpClient;
use serde::{Deserialize, Serialize};
use tokio::sync::mpsc;

use super::super::{ChatClient, ChatRequest, ChatResponse, LlmError, LlmResult};

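/// Chat client for a llama.cpp server exposing the OpenAI-compatible chat completions API.
///
/// A minimal usage sketch, assuming a llama.cpp server on `localhost:8080` and a
/// `ChatRequest` value built elsewhere in this crate:
///
/// ```ignore
/// let client = LlamaClient::localhost();
/// let response = client.chat(request).await?;
/// println!("{}", response.content);
/// ```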
pub struct LlamaClient {
    base_url: String,
    model: String,
    http_client: HttpClient,
}

impl LlamaClient {
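    /// Creates a client for the given base URL and model name, using a default HTTP client.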
    pub fn new(base_url: impl Into<String>, model: impl Into<String>) -> Self {
        Self {
            base_url: base_url.into(),
            model: model.into(),
            http_client: HttpClient::new(),
        }
    }

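    /// Creates a client that reuses a caller-supplied `reqwest` HTTP client.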
    pub fn with_http_client(
        base_url: impl Into<String>,
        model: impl Into<String>,
        http_client: HttpClient,
    ) -> Self {
        Self {
            base_url: base_url.into(),
            model: model.into(),
            http_client,
        }
    }

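    /// Convenience constructor for a local llama.cpp server on the default port 8080.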
    pub fn localhost() -> Self {
        Self::new("http://localhost:8080", "llama")
    }

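    /// Convenience constructor for a local llama.cpp server on a custom port.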
    pub fn localhost_with_port(port: u16) -> Self {
        Self::new(format!("http://localhost:{}", port), "llama")
    }

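    /// Creates a client whose HTTP client accepts invalid TLS certificates (for example a
    /// local server with a self-signed certificate). Do not use this against untrusted hosts.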
    pub fn insecure(base_url: impl Into<String>, model: impl Into<String>) -> Self {
        let http_client = HttpClient::builder()
            .danger_accept_invalid_certs(true)
            .build()
            .expect("Failed to build HTTP client");

        Self::with_http_client(base_url, model, http_client)
    }

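    /// Convenience constructor for a local HTTPS llama.cpp server whose certificate cannot be verified.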
    pub fn localhost_insecure(port: u16) -> Self {
        Self::insecure(format!("https://localhost:{}", port), "llama")
    }

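    /// Returns the configured model name.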
    pub fn model(&self) -> &str {
        &self.model
    }

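    /// Returns the provider identifier reported by this client.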
    pub fn provider(&self) -> &str {
        "llama.cpp"
    }
}

#[async_trait]
impl ChatClient for LlamaClient {
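    /// Sends a non-streaming chat completion request to `{base_url}/chat/completions` and
    /// maps the llama.cpp response into the runtime's `ChatResponse`.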
    async fn chat(&self, request: ChatRequest) -> LlmResult<ChatResponse> {
        let url = format!("{}/chat/completions", self.base_url);

        let llama_request = LlamaChatRequest {
            model: self.model.clone(),
            messages: request.messages,
            temperature: request.temperature,
            max_tokens: request.max_tokens,
            top_p: request.top_p,
            tools: request.tools,
        };

        let response = self
            .http_client
            .post(&url)
            .header("Content-Type", "application/json")
            .json(&llama_request)
            .send()
            .await
            .map_err(|e| LlmError::NetworkError(e.to_string()))?;

        let status = response.status();
        if !status.is_success() {
            let error_text = response.text().await.unwrap_or_default();
            return Err(LlmError::ApiError(format!(
                "Status {}: {}",
                status, error_text
            )));
        }

        let llama_response: LlamaChatResponse = response
            .json()
            .await
            .map_err(|e| LlmError::ParseError(e.to_string()))?;

        let choice = llama_response
            .choices
            .first()
            .ok_or_else(|| LlmError::ParseError("No choices in response".to_string()))?;

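        // Convert llama.cpp tool calls into the runtime's shared ToolCall/FunctionCall types.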
        let tool_calls = choice.message.tool_calls.as_ref().map(|calls| {
            calls
                .iter()
                .map(|tc| super::super::types::ToolCall {
                    id: tc.id.clone(),
                    r#type: tc.r#type.clone(),
                    function: super::super::types::FunctionCall {
                        name: tc.function.name.clone(),
                        arguments: tc.function.arguments.clone(),
                    },
                })
                .collect()
        });

        Ok(ChatResponse {
            content: choice.message.content.clone(),
            model: llama_response.model.unwrap_or_else(|| self.model.clone()),
            usage: llama_response.usage.map(|u| super::super::types::Usage {
                prompt_tokens: u.prompt_tokens,
                completion_tokens: u.completion_tokens,
                total_tokens: u.total_tokens,
            }),
            finish_reason: choice.finish_reason.clone(),
            tool_calls,
        })
    }

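    /// Streams a chat completion from `{base_url}/v1/chat/completions`, forwarding content
    /// deltas over `tx`, then issues a separate non-streaming request to build the final
    /// `ChatResponse`.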
    async fn chat_stream(
        &self,
        request: ChatRequest,
        tx: mpsc::Sender<String>,
    ) -> LlmResult<ChatResponse> {
        let url = format!("{}/v1/chat/completions", self.base_url);

        let llama_request = LlamaChatRequest {
            model: self.model.clone(),
            messages: request.messages.clone(),
            temperature: request.temperature,
            max_tokens: request.max_tokens,
            top_p: request.top_p,
            tools: request.tools.clone(),
        };

        let response = self
            .http_client
            .post(&url)
            .header("Content-Type", "application/json")
            .header("Accept", "text/event-stream")
            .json(&serde_json::json!({
                "model": llama_request.model,
                "messages": llama_request.messages,
                "temperature": llama_request.temperature,
                "max_tokens": llama_request.max_tokens,
                "top_p": llama_request.top_p,
                "tools": llama_request.tools,
                "stream": true,
            }))
            .send()
            .await
            .map_err(|e| LlmError::NetworkError(e.to_string()))?;

        if !response.status().is_success() {
            let status = response.status();
            let error_text = response.text().await.unwrap_or_default();
            return Err(LlmError::ApiError(format!(
                "HTTP {}: {}",
                status, error_text
            )));
        }

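        // Forward each SSE content delta to the caller. This assumes every chunk contains
        // whole `data:` lines; an event split across chunk boundaries is silently dropped.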
        let mut stream = response.bytes_stream();
        while let Some(chunk_result) = stream.next().await {
            let bytes = chunk_result.map_err(|e| LlmError::NetworkError(e.to_string()))?;

            let text = String::from_utf8_lossy(&bytes);
            for line in text.lines() {
                if let Some(json_str) = line.strip_prefix("data: ") {
                    if json_str.trim() == "[DONE]" {
                        continue;
                    }
                    if let Ok(parsed) = serde_json::from_str::<serde_json::Value>(json_str) {
                        if let Some(delta) = parsed
                            .get("choices")
                            .and_then(|c| c.get(0))
                            .and_then(|c| c.get("delta"))
                            .and_then(|d| d.get("content"))
                            .and_then(|c| c.as_str())
                        {
                            let _ = tx.send(delta.to_string()).await;
                        }
                    }
                }
            }
        }

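        // The streamed deltas were only forwarded over `tx`; obtain the complete response
        // (content, usage, tool calls) with a second, non-streaming request.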
        self.chat(request).await
    }
}

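// Wire types mirroring the llama.cpp (OpenAI-compatible) chat completions JSON.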
#[derive(Debug, Serialize)]
struct LlamaChatRequest {
    model: String,
    messages: Vec<super::super::types::ChatMessage>,

    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f32>,

    #[serde(skip_serializing_if = "Option::is_none")]
    max_tokens: Option<u32>,

    #[serde(skip_serializing_if = "Option::is_none")]
    top_p: Option<f32>,

    #[serde(skip_serializing_if = "Option::is_none")]
    tools: Option<Vec<serde_json::Value>>,
}

#[derive(Debug, Deserialize)]
struct LlamaChatResponse {
    model: Option<String>,
    choices: Vec<Choice>,
    usage: Option<UsageInfo>,
}

#[derive(Debug, Deserialize)]
struct Choice {
    message: Message,
    finish_reason: Option<String>,
}

#[derive(Debug, Deserialize)]
struct Message {
    #[serde(default)]
    content: String,

    // A missing `tool_calls` field deserializes to `None`; no serialize-only attribute is needed
    // on this deserialize-only struct.
    tool_calls: Option<Vec<LlamaToolCall>>,
}

#[derive(Debug, Deserialize)]
struct LlamaToolCall {
    id: String,
    r#type: String,
    function: LlamaFunction,
}

#[derive(Debug, Deserialize)]
struct LlamaFunction {
    name: String,
    arguments: String,
}

#[derive(Debug, Deserialize)]
struct UsageInfo {
    prompt_tokens: u32,
    completion_tokens: u32,
    total_tokens: u32,
}