use std::collections::HashMap;

use async_trait::async_trait;
use serde::{Deserialize, Serialize};

use crate::models::{Model, ModelConfig, StreamEventStream};
use crate::types::content::{Message, Role, SystemContentBlock};
use crate::types::errors::StrandsError;
use crate::types::streaming::{StopReason, StreamEvent};
use crate::types::tools::{ToolChoice, ToolSpec};

/// Configuration for a llama.cpp model provider.
#[derive(Debug, Clone)]
pub struct LlamaCppConfig {
    /// Identifier of the model to request from the server.
    pub model_id: String,
    /// Base URL of the llama.cpp server, e.g. `http://localhost:8080`.
    pub base_url: String,
    /// Additional request parameters forwarded to the completions endpoint.
    pub params: HashMap<String, serde_json::Value>,
}

impl Default for LlamaCppConfig {
    fn default() -> Self {
        Self {
            model_id: "default".to_string(),
            base_url: "http://localhost:8080".to_string(),
            params: HashMap::new(),
        }
    }
}

impl LlamaCppConfig {
    /// Creates a config pointing at the given llama.cpp server base URL.
    pub fn new(base_url: impl Into<String>) -> Self {
        Self {
            base_url: base_url.into(),
            ..Default::default()
        }
    }

    /// Sets the model identifier.
    pub fn with_model_id(mut self, model_id: impl Into<String>) -> Self {
        self.model_id = model_id.into();
        self
    }

    /// Sets an arbitrary request parameter.
    pub fn with_param(mut self, key: impl Into<String>, value: serde_json::Value) -> Self {
        self.params.insert(key.into(), value);
        self
    }

    /// Sets the sampling temperature.
    pub fn with_temperature(mut self, temperature: f32) -> Self {
        self.params.insert("temperature".to_string(), serde_json::json!(temperature));
        self
    }

    /// Sets the maximum number of tokens to generate.
    pub fn with_max_tokens(mut self, max_tokens: u32) -> Self {
        self.params.insert("max_tokens".to_string(), serde_json::json!(max_tokens));
        self
    }
}

/// Request body for the OpenAI-compatible `/v1/chat/completions` endpoint.
#[derive(Debug, Serialize)]
struct LlamaCppRequest {
    model: String,
    messages: Vec<LlamaCppMessage>,
    #[serde(skip_serializing_if = "Option::is_none")]
    max_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tools: Option<Vec<LlamaCppTool>>,
    stream: bool,
    /// Extra parameters flattened into the top level of the request body;
    /// keys here must not duplicate the explicit fields above.
    #[serde(flatten)]
    extra: HashMap<String, serde_json::Value>,
}

#[derive(Debug, Serialize, Deserialize)]
struct LlamaCppMessage {
    role: String,
    content: serde_json::Value,
}

#[derive(Debug, Serialize)]
struct LlamaCppTool {
    #[serde(rename = "type")]
    tool_type: String,
    function: LlamaCppFunction,
}

#[derive(Debug, Serialize)]
struct LlamaCppFunction {
    name: String,
    description: String,
    parameters: serde_json::Value,
}

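/// Model provider backed by a llama.cpp server's OpenAI-compatible
/// `/v1/chat/completions` endpoint.
///
/// A minimal construction sketch, assuming a llama.cpp server is listening
/// on `http://localhost:8080` and serves a model named `my-model` (both
/// values are illustrative):
///
/// ```ignore
/// let model = LlamaCppModel::new(
///     LlamaCppConfig::new("http://localhost:8080")
///         .with_model_id("my-model")
///         .with_temperature(0.7)
///         .with_max_tokens(512),
/// );
/// ```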
pub struct LlamaCppModel {
    /// Shared model configuration (model id, etc.).
    config: ModelConfig,
    /// llama.cpp-specific settings (base URL, request params).
    llamacpp_config: LlamaCppConfig,
    client: reqwest::Client,
}

impl LlamaCppModel {
    pub fn new(config: LlamaCppConfig) -> Self {
        let model_config = ModelConfig::new(&config.model_id);

        Self {
            config: model_config,
            llamacpp_config: config,
            client: reqwest::Client::new(),
        }
    }

    /// Converts internal messages (plus an optional system prompt) into the
    /// chat-completions message format.
    fn convert_messages(&self, messages: &[Message], system_prompt: Option<&str>) -> Vec<LlamaCppMessage> {
        let mut result = Vec::new();

        // The system prompt, if any, becomes the leading "system" message.
        if let Some(prompt) = system_prompt {
            result.push(LlamaCppMessage {
                role: "system".to_string(),
                content: serde_json::json!(prompt),
            });
        }

        for msg in messages {
            let role = match msg.role {
                Role::User => "user",
                Role::Assistant => "assistant",
            };

            // Flatten the message's content blocks to plain text.
            let content = msg.text_content();

            result.push(LlamaCppMessage {
                role: role.to_string(),
                content: serde_json::json!(content),
            });
        }

        result
    }

    /// Converts tool specs into OpenAI-style function tool definitions.
    fn convert_tools(&self, tool_specs: &[ToolSpec]) -> Vec<LlamaCppTool> {
        tool_specs
            .iter()
            .map(|spec| LlamaCppTool {
                tool_type: "function".to_string(),
                function: LlamaCppFunction {
                    name: spec.name.clone(),
                    description: spec.description.clone(),
                    // Fall back to null if the schema cannot be serialized.
                    parameters: serde_json::to_value(&spec.input_schema).unwrap_or_default(),
                },
            })
            .collect()
    }
}

#[async_trait]
impl Model for LlamaCppModel {
    fn config(&self) -> &ModelConfig {
        &self.config
    }

    fn update_config(&mut self, config: ModelConfig) {
        self.config = config;
    }

    fn stream<'a>(
        &'a self,
        messages: &'a [Message],
        tool_specs: Option<&'a [ToolSpec]>,
        system_prompt: Option<&'a str>,
        _tool_choice: Option<ToolChoice>,
        _system_prompt_content: Option<&'a [SystemContentBlock]>,
    ) -> StreamEventStream<'a> {
        // Clone the borrowed inputs so the async stream can own them.
        let messages = messages.to_vec();
        let tool_specs = tool_specs.map(|t| t.to_vec());
        let system_prompt = system_prompt.map(|s| s.to_string());

        Box::pin(async_stream::stream! {
            let llamacpp_messages = self.convert_messages(&messages, system_prompt.as_deref());
            let tools = tool_specs.as_ref().map(|specs| self.convert_tools(specs));

            // Lift well-known sampling params into dedicated request fields.
            let max_tokens = self.llamacpp_config.params
                .get("max_tokens")
                .and_then(|v| v.as_u64())
                .map(|v| v as u32);

            let temperature = self.llamacpp_config.params
                .get("temperature")
                .and_then(|v| v.as_f64())
                .map(|v| v as f32);

            // Drop the lifted params from the flattened extras so the
            // serialized body contains no duplicate keys.
            let mut extra = self.llamacpp_config.params.clone();
            extra.remove("max_tokens");
            extra.remove("temperature");

            let request = LlamaCppRequest {
                model: self.config.model_id.clone(),
                messages: llamacpp_messages,
                max_tokens,
                temperature,
                tools,
                stream: true,
                extra,
            };

            let url = format!("{}/v1/chat/completions", self.llamacpp_config.base_url);

            // `.json()` serializes the body and sets the Content-Type header.
            let response = match self.client
                .post(&url)
                .json(&request)
                .send()
                .await
            {
                Ok(resp) => resp,
                Err(e) => {
                    yield Err(StrandsError::NetworkError(e.to_string()));
                    return;
                }
            };

            if !response.status().is_success() {
                let status = response.status();
                let body = response.text().await.unwrap_or_default();

                if status == reqwest::StatusCode::TOO_MANY_REQUESTS {
                    yield Err(StrandsError::ModelThrottled {
                        message: "llama.cpp rate limit exceeded".into(),
                    });
                } else {
                    yield Err(StrandsError::ModelError {
                        message: format!("llama.cpp API error {}: {}", status, body),
                        source: None,
                    });
                }
                return;
            }

            yield Ok(StreamEvent::message_start(Role::Assistant));

            // Note: this buffers the complete response body before parsing,
            // rather than handling SSE chunks incrementally as they arrive.
            let body = match response.text().await {
                Ok(b) => b,
                Err(e) => {
                    yield Err(StrandsError::NetworkError(e.to_string()));
                    return;
                }
            };

            // Parse the server-sent-event lines ("data: {...}") emitted by
            // the OpenAI-compatible streaming endpoint.
            for line in body.lines() {
                if let Some(data) = line.strip_prefix("data: ") {
                    if data == "[DONE]" {
                        break;
                    }

                    if let Ok(chunk) = serde_json::from_str::<serde_json::Value>(data) {
                        if let Some(choices) = chunk.get("choices").and_then(|c| c.as_array()) {
                            for choice in choices {
                                if let Some(content) = choice
                                    .get("delta")
                                    .and_then(|d| d.get("content"))
                                    .and_then(|c| c.as_str())
                                {
                                    yield Ok(StreamEvent::text_delta(0, content));
                                }
                            }
                        }
                    }
                }
            }

            yield Ok(StreamEvent::message_stop(StopReason::EndTurn));
        })
    }
}
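
// A hedged consumption sketch for `LlamaCppModel::stream`, assuming
// `StreamEventStream` is a pinned, boxed `futures::Stream` of
// `Result<StreamEvent, StrandsError>` (as the `Box::pin(async_stream::stream!)`
// construction above suggests); `model` and `messages` are illustrative:
//
//     use futures::StreamExt;
//
//     let mut events = model.stream(&messages, None, Some("Be concise."), None, None);
//     while let Some(event) = events.next().await {
//         match event {
//             Ok(ev) => { /* handle the StreamEvent */ }
//             Err(e) => eprintln!("stream error: {e}"),
//         }
//     }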

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_llamacpp_config() {
        let config = LlamaCppConfig::new("http://localhost:8080")
            .with_model_id("my-model")
            .with_temperature(0.7);

        assert_eq!(config.base_url, "http://localhost:8080");
        assert_eq!(config.model_id, "my-model");
        assert!(config.params.contains_key("temperature"));
    }
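
    // Additional coverage sketches; these exercise only items defined in
    // this file, so no running server or external fixtures are assumed.
    #[test]
    fn test_llamacpp_config_params() {
        let config = LlamaCppConfig::default()
            .with_max_tokens(512)
            .with_param("top_p", serde_json::json!(0.9));

        assert_eq!(config.model_id, "default");
        assert_eq!(config.params.get("max_tokens"), Some(&serde_json::json!(512)));
        assert_eq!(config.params.get("top_p"), Some(&serde_json::json!(0.9)));
    }

    #[test]
    fn test_request_serialization_skips_unset_options() {
        let request = LlamaCppRequest {
            model: "default".to_string(),
            messages: vec![LlamaCppMessage {
                role: "user".to_string(),
                content: serde_json::json!("hi"),
            }],
            max_tokens: None,
            temperature: None,
            tools: None,
            stream: true,
            extra: HashMap::new(),
        };

        let json = serde_json::to_value(&request).unwrap();
        // `skip_serializing_if` should drop the unset optional fields.
        assert!(json.get("max_tokens").is_none());
        assert!(json.get("temperature").is_none());
        assert!(json.get("tools").is_none());
        assert_eq!(json["stream"], serde_json::json!(true));
    }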
}