agent_runtime/llm/provider/llama.rs

use async_trait::async_trait;
use futures::stream::StreamExt;
use reqwest::Client as HttpClient;
use serde::{Deserialize, Serialize};

use super::super::{ChatClient, ChatRequest, ChatResponse, LlmError, LlmResult, TextStream};

/// Chat client for a llama.cpp server exposing the OpenAI-compatible
/// `/v1/chat/completions` endpoint.
pub struct LlamaClient {
    base_url: String,
    model: String,
    http_client: HttpClient,
}

impl LlamaClient {
    /// Creates a client for the given base URL and model name, using a
    /// default `reqwest` HTTP client.
    pub fn new(base_url: impl Into<String>, model: impl Into<String>) -> Self {
        Self {
            base_url: base_url.into(),
            model: model.into(),
            http_client: HttpClient::new(),
        }
    }

    /// Creates a client that reuses a caller-supplied HTTP client.
    pub fn with_http_client(
        base_url: impl Into<String>,
        model: impl Into<String>,
        http_client: HttpClient,
    ) -> Self {
        Self {
            base_url: base_url.into(),
            model: model.into(),
            http_client,
        }
    }

    /// Connects to a llama.cpp server on the default local port (8080).
    pub fn localhost() -> Self {
        Self::new("http://localhost:8080", "llama")
    }

    /// Connects to a llama.cpp server on the given local port.
    pub fn localhost_with_port(port: u16) -> Self {
        Self::new(format!("http://localhost:{}", port), "llama")
    }

    /// Creates a client that accepts invalid TLS certificates. Intended only
    /// for local or development servers with self-signed certificates.
    pub fn insecure(base_url: impl Into<String>, model: impl Into<String>) -> Self {
        let http_client = HttpClient::builder()
            .danger_accept_invalid_certs(true)
            .build()
            .expect("Failed to build HTTP client");

        Self::with_http_client(base_url, model, http_client)
    }

    /// Connects over HTTPS to a local server, skipping certificate checks.
    pub fn localhost_insecure(port: u16) -> Self {
        Self::insecure(format!("https://localhost:{}", port), "llama")
    }
}

#[async_trait]
impl ChatClient for LlamaClient {
    async fn chat(&self, request: ChatRequest) -> LlmResult<ChatResponse> {
        let url = format!("{}/v1/chat/completions", self.base_url);

        let llama_request = LlamaChatRequest {
            model: self.model.clone(),
            messages: request.messages,
            temperature: request.temperature,
            max_tokens: request.max_tokens,
            top_p: request.top_p,
        };

        let response = self
            .http_client
            .post(&url)
            .header("Content-Type", "application/json")
            .json(&llama_request)
            .send()
            .await
            .map_err(|e| LlmError::NetworkError(e.to_string()))?;

        let status = response.status();
        if !status.is_success() {
            let error_text = response.text().await.unwrap_or_default();
            return Err(LlmError::ApiError(format!(
                "Status {}: {}",
                status, error_text
            )));
        }

        let llama_response: LlamaChatResponse = response
            .json()
            .await
            .map_err(|e| LlmError::ParseError(e.to_string()))?;

        let choice = llama_response
            .choices
            .first()
            .ok_or_else(|| LlmError::ParseError("No choices in response".to_string()))?;

        Ok(ChatResponse {
            content: choice.message.content.clone(),
            model: llama_response.model.unwrap_or_else(|| self.model.clone()),
            usage: llama_response.usage.map(|u| super::super::types::Usage {
                prompt_tokens: u.prompt_tokens,
                completion_tokens: u.completion_tokens,
                total_tokens: u.total_tokens,
            }),
            finish_reason: choice.finish_reason.clone(),
        })
    }

    async fn chat_stream(&self, request: ChatRequest) -> LlmResult<TextStream> {
        let url = format!("{}/v1/chat/completions", self.base_url);

        let llama_request = LlamaChatRequest {
            model: self.model.clone(),
            messages: request.messages,
            temperature: request.temperature,
            max_tokens: request.max_tokens,
            top_p: request.top_p,
        };

        // Serialize the request once and add the `stream` flag, so optional
        // fields stay omitted (instead of being sent as `null`) exactly as in
        // the non-streaming path.
        let mut body = serde_json::to_value(&llama_request)
            .map_err(|e| LlmError::ParseError(e.to_string()))?;
        body["stream"] = serde_json::Value::Bool(true);

        let response = self
            .http_client
            .post(&url)
            .header("Content-Type", "application/json")
            .header("Accept", "text/event-stream")
            .json(&body)
            .send()
            .await
            .map_err(|e| LlmError::NetworkError(e.to_string()))?;

        if !response.status().is_success() {
            let status = response.status();
            let error_text = response.text().await.unwrap_or_default();
            return Err(LlmError::ApiError(format!(
                "HTTP {}: {}",
                status, error_text
            )));
        }

        let stream = response.bytes_stream();
        let text_stream = stream.map(|chunk_result| {
            chunk_result
                .map_err(|e| LlmError::NetworkError(e.to_string()))
                .map(|bytes| {
                    let text = String::from_utf8_lossy(&bytes);
                    // A single chunk may carry several SSE `data:` lines;
                    // concatenate every delta rather than returning only the
                    // first one found.
                    let mut content = String::new();
                    for line in text.lines() {
                        if let Some(json_str) = line.strip_prefix("data: ") {
                            if json_str.trim() == "[DONE]" {
                                continue;
                            }
                            if let Ok(parsed) =
                                serde_json::from_str::<serde_json::Value>(json_str)
                            {
                                if let Some(delta) = parsed
                                    .get("choices")
                                    .and_then(|c| c.get(0))
                                    .and_then(|c| c.get("delta"))
                                    .and_then(|d| d.get("content"))
                                    .and_then(|c| c.as_str())
                                {
                                    content.push_str(delta);
                                }
                            }
                        }
                    }
                    content
                })
        });

        Ok(Box::pin(text_stream))
    }

    fn model(&self) -> &str {
        &self.model
    }

    fn provider(&self) -> &str {
        "llama.cpp"
    }
}
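
// Consumption sketch (an assumption, not part of the original file):
// `chat_stream` returns the pinned, boxed stream built above, whose items
// are `LlmResult<String>` text chunks, so a caller would typically drain it
// with `StreamExt::next`, along these lines:
//
//     let mut stream = client.chat_stream(request).await?;
//     while let Some(chunk) = stream.next().await {
//         print!("{}", chunk?);
//     }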

/// Request body in the OpenAI-compatible format accepted by llama.cpp.
#[derive(Debug, Serialize)]
struct LlamaChatRequest {
    model: String,
    messages: Vec<super::super::types::ChatMessage>,

    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f32>,

    #[serde(skip_serializing_if = "Option::is_none")]
    max_tokens: Option<u32>,

    #[serde(skip_serializing_if = "Option::is_none")]
    top_p: Option<f32>,
}

/// Subset of the chat completion response that this client consumes.
#[derive(Debug, Deserialize)]
struct LlamaChatResponse {
    model: Option<String>,
    choices: Vec<Choice>,
    usage: Option<UsageInfo>,
}

#[derive(Debug, Deserialize)]
struct Choice {
    message: Message,
    finish_reason: Option<String>,
}

#[derive(Debug, Deserialize)]
struct Message {
    content: String,
}

#[derive(Debug, Deserialize)]
struct UsageInfo {
    prompt_tokens: u32,
    completion_tokens: u32,
    total_tokens: u32,
}
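
#[cfg(test)]
mod tests {
    use super::*;

    // Illustrative sketch, not from the original file: exercises only the
    // constructors and the metadata accessors, since the exact shape of
    // `ChatRequest` is defined in the parent module and not shown here.
    #[test]
    fn localhost_uses_default_model_and_provider() {
        let client = LlamaClient::localhost();
        assert_eq!(client.model(), "llama");
        assert_eq!(client.provider(), "llama.cpp");
    }

    #[test]
    fn localhost_with_port_sets_base_url() {
        let client = LlamaClient::localhost_with_port(9001);
        assert_eq!(client.base_url, "http://localhost:9001");
    }
}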