use crate::client::{
    self, Capabilities, Capable, DebugExt, Nothing, Provider, ProviderBuilder, ProviderClient,
};
use crate::completion::GetTokenUsage;
use crate::http_client::sse::{Event, GenericEventSource};
use crate::http_client::{self, HttpClientExt};
use crate::json_utils::empty_or_none;
use crate::providers::openai::{self, StreamingToolCall};
use crate::{
    completion::{self, CompletionError, CompletionRequest},
    embeddings::{self, EmbeddingError},
    json_utils,
};
use async_stream::stream;
use bytes::Bytes;
use futures::StreamExt;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use tracing::{Level, info_span};
use tracing_futures::Instrument;

const LLAMAFILE_API_BASE_URL: &str = "http://localhost:8080";

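/// Default model identifier used when addressing llamafile's llama.cpp-based server.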
pub const LLAMA_CPP: &str = "LLaMA_CPP";

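/// Provider extension wiring llamafile into the generic client: completions and
/// embeddings are supported; transcription and model listing are not.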
#[derive(Debug, Default, Clone, Copy)]
pub struct LlamafileExt;

#[derive(Debug, Default, Clone, Copy)]
pub struct LlamafileBuilder;

impl Provider for LlamafileExt {
    type Builder = LlamafileBuilder;
    const VERIFY_PATH: &'static str = "v1/models";
}

impl<H> Capabilities<H> for LlamafileExt {
    type Completion = Capable<CompletionModel<H>>;
    type Embeddings = Capable<EmbeddingModel<H>>;
    type Transcription = Nothing;
    type ModelListing = Nothing;
    #[cfg(feature = "image")]
    type ImageGeneration = Nothing;
    #[cfg(feature = "audio")]
    type AudioGeneration = Nothing;
}

impl DebugExt for LlamafileExt {}

impl ProviderBuilder for LlamafileBuilder {
    type Extension<H>
        = LlamafileExt
    where
        H: HttpClientExt;
    type ApiKey = Nothing;

    const BASE_URL: &'static str = LLAMAFILE_API_BASE_URL;

    fn build<H>(
        _builder: &client::ClientBuilder<Self, Self::ApiKey, H>,
    ) -> http_client::Result<Self::Extension<H>>
    where
        H: HttpClientExt,
    {
        Ok(LlamafileExt)
    }
}

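/// Client for a llamafile server, defaulting to `http://localhost:8080`.
///
/// Minimal usage sketch (the `rig::...` paths below are illustrative of how this crate
/// is typically consumed; adjust them to your setup):
///
/// ```rust,ignore
/// use rig::client::Nothing;
/// use rig::providers::llamafile;
///
/// // Points at the default base URL, http://localhost:8080.
/// let client = llamafile::Client::builder()
///     .api_key(Nothing)
///     .build()
///     .expect("failed to build llamafile client");
/// ```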
pub type Client<H = reqwest::Client> = client::Client<LlamafileExt, H>;
pub type ClientBuilder<H = reqwest::Client> = client::ClientBuilder<LlamafileBuilder, Nothing, H>;

impl Client {
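    /// Creates a llamafile client that targets `base_url` instead of the default
    /// `http://localhost:8080`.
    ///
    /// # Panics
    ///
    /// Panics if the underlying client builder fails.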
    pub fn from_url(base_url: &str) -> Self {
        Self::builder()
            .api_key(Nothing)
            .base_url(base_url)
            .build()
            .expect("Failed to build llamafile client")
    }
}

impl ProviderClient for Client {
    type Input = Nothing;

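    /// Builds a client from the `LLAMAFILE_API_BASE_URL` environment variable.
    ///
    /// # Panics
    ///
    /// Panics if `LLAMAFILE_API_BASE_URL` is not set.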
    fn from_env() -> Self {
        let api_base =
            std::env::var("LLAMAFILE_API_BASE_URL").expect("LLAMAFILE_API_BASE_URL not set");
        Self::from_url(&api_base)
    }

    fn from_val(_: Self::Input) -> Self {
        Self::builder().api_key(Nothing).build().unwrap()
    }
}

#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    message: String,
}

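/// llamafile responses either carry the expected payload or an error object with a
/// `message` field; `#[serde(untagged)]` tries both shapes in order.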
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ApiResponse<T> {
    Ok(T),
    Err(ApiErrorResponse),
}

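/// Request body for `v1/chat/completions`, mirroring the OpenAI chat schema.
/// Any `additional_params` are flattened into the top-level JSON object.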
#[derive(Debug, Serialize, Deserialize)]
struct LlamafileCompletionRequest {
    model: String,
    messages: Vec<openai::Message>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    max_tokens: Option<u64>,
    #[serde(skip_serializing_if = "Vec::is_empty")]
    tools: Vec<openai::ToolDefinition>,
    #[serde(flatten, skip_serializing_if = "Option::is_none")]
    additional_params: Option<serde_json::Value>,
}

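/// Builds the request payload from a generic [`CompletionRequest`]: an optional system
/// message from the preamble comes first, then any normalized documents, then the chat
/// history. A model set on the request overrides the model bound to the client.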
impl TryFrom<(&str, CompletionRequest)> for LlamafileCompletionRequest {
    type Error = CompletionError;

    fn try_from((model, req): (&str, CompletionRequest)) -> Result<Self, Self::Error> {
        if req.output_schema.is_some() {
            tracing::warn!("Structured outputs may not be supported by llamafile");
        }
        let model = req.model.clone().unwrap_or_else(|| model.to_string());

        let mut full_history: Vec<openai::Message> = match &req.preamble {
            Some(preamble) => vec![openai::Message::system(preamble)],
            None => vec![],
        };

        if let Some(docs) = req.normalized_documents() {
            let docs: Vec<openai::Message> = docs.try_into()?;
            full_history.extend(docs);
        }

        let chat_history: Vec<openai::Message> = req
            .chat_history
            .clone()
            .into_iter()
            .map(|msg| msg.try_into())
            .collect::<Result<Vec<Vec<openai::Message>>, _>>()?
            .into_iter()
            .flatten()
            .collect();

        full_history.extend(chat_history);

        Ok(Self {
            model,
            messages: full_history,
            temperature: req.temperature,
            max_tokens: req.max_tokens,
            tools: req
                .tools
                .into_iter()
                .map(openai::ToolDefinition::from)
                .collect(),
            additional_params: req.additional_params,
        })
    }
}

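/// Completion model backed by llamafile's OpenAI-compatible `v1/chat/completions` endpoint.
///
/// Minimal usage sketch (the `rig::...` path is illustrative; a llamafile server is
/// assumed to be running locally):
///
/// ```rust,ignore
/// use rig::providers::llamafile::{self, LLAMA_CPP};
///
/// let client = llamafile::Client::from_url("http://localhost:8080");
/// let model = llamafile::CompletionModel::new(client, LLAMA_CPP);
/// ```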
#[derive(Clone)]
pub struct CompletionModel<T = reqwest::Client> {
    client: Client<T>,
    pub model: String,
}

impl<T> CompletionModel<T> {
    pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
        Self {
            client,
            model: model.into(),
        }
    }
}

impl<T> completion::CompletionModel for CompletionModel<T>
where
    T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
{
    type Response = openai::CompletionResponse;
    type StreamingResponse = StreamingCompletionResponse;
    type Client = Client<T>;

    fn make(client: &Self::Client, model: impl Into<String>) -> Self {
        Self::new(client.clone(), model)
    }

    async fn completion(
        &self,
        completion_request: CompletionRequest,
    ) -> Result<completion::CompletionResponse<openai::CompletionResponse>, CompletionError> {
        let span = if tracing::Span::current().is_disabled() {
            info_span!(
                target: "rig::completions",
                "chat",
                gen_ai.operation.name = "chat",
                gen_ai.provider.name = "llamafile",
                gen_ai.request.model = self.model,
                gen_ai.system_instructions = completion_request.preamble,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };

        let request =
            LlamafileCompletionRequest::try_from((self.model.as_ref(), completion_request))?;

        if tracing::enabled!(Level::TRACE) {
            tracing::trace!(target: "rig::completions",
                "Llamafile completion request: {}",
                serde_json::to_string_pretty(&request)?
            );
        }

        let body = serde_json::to_vec(&request)?;
        let req = self
            .client
            .post("v1/chat/completions")?
            .body(body)
            .map_err(|e| CompletionError::HttpError(e.into()))?;

        async move {
            let response = self.client.send::<_, Bytes>(req).await?;
            let status = response.status();
            let response_body = response.into_body().into_future().await?.to_vec();

            if status.is_success() {
                match serde_json::from_slice::<ApiResponse<openai::CompletionResponse>>(
                    &response_body,
                )? {
                    ApiResponse::Ok(response) => {
                        let span = tracing::Span::current();
                        span.record("gen_ai.response.id", response.id.clone());
                        span.record("gen_ai.response.model", response.model.clone());
                        if let Some(ref usage) = response.usage {
                            span.record("gen_ai.usage.input_tokens", usage.prompt_tokens);
                            span.record(
                                "gen_ai.usage.output_tokens",
                                usage.total_tokens - usage.prompt_tokens,
                            );
                        }

                        if tracing::enabled!(Level::TRACE) {
                            tracing::trace!(target: "rig::completions",
                                "Llamafile completion response: {}",
                                serde_json::to_string_pretty(&response)?
                            );
                        }

                        response.try_into()
                    }
                    ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
                }
            } else {
                Err(CompletionError::ProviderError(
                    String::from_utf8_lossy(&response_body).to_string(),
                ))
            }
        }
        .instrument(span)
        .await
    }

    async fn stream(
        &self,
        completion_request: CompletionRequest,
    ) -> Result<
        crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
        CompletionError,
    > {
        let span = if tracing::Span::current().is_disabled() {
            info_span!(
                target: "rig::completions",
                "chat_streaming",
                gen_ai.operation.name = "chat_streaming",
                gen_ai.provider.name = "llamafile",
                gen_ai.request.model = self.model,
                gen_ai.system_instructions = completion_request.preamble,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };

        let mut request =
            LlamafileCompletionRequest::try_from((self.model.as_ref(), completion_request))?;

        let params = json_utils::merge(
            request.additional_params.unwrap_or(serde_json::json!({})),
            serde_json::json!({"stream": true}),
        );
        request.additional_params = Some(params);

        if tracing::enabled!(Level::TRACE) {
            tracing::trace!(target: "rig::completions",
                "Llamafile streaming completion request: {}",
                serde_json::to_string_pretty(&request)?
            );
        }

        let body = serde_json::to_vec(&request)?;
        let req = self
            .client
            .post("v1/chat/completions")?
            .body(body)
            .map_err(|e| CompletionError::HttpError(e.into()))?;

        send_streaming_request(self.client.clone(), req, span).await
    }
}

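/// Incremental delta carried by a streaming SSE chunk: optional text content plus any
/// (possibly partial) tool calls, following the OpenAI chunk format.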
#[derive(Deserialize, Debug)]
struct StreamingDelta {
    #[serde(default)]
    content: Option<String>,
    #[serde(default, deserialize_with = "json_utils::null_or_vec")]
    tool_calls: Vec<StreamingToolCall>,
}

#[derive(Deserialize, Debug)]
struct StreamingChoice {
    delta: StreamingDelta,
}

#[derive(Deserialize, Debug)]
struct StreamingCompletionChunk {
    choices: Vec<StreamingChoice>,
    usage: Option<openai::Usage>,
}

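/// Final payload yielded once a streaming completion ends, carrying the token usage
/// reported by the server; output tokens are derived as `total_tokens - prompt_tokens`.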
#[derive(Clone, Deserialize, Serialize, Debug)]
pub struct StreamingCompletionResponse {
    pub usage: openai::Usage,
}

impl GetTokenUsage for StreamingCompletionResponse {
    fn token_usage(&self) -> Option<crate::completion::Usage> {
        let mut usage = crate::completion::Usage::new();
        usage.input_tokens = self.usage.prompt_tokens as u64;
        usage.total_tokens = self.usage.total_tokens as u64;
        usage.output_tokens = self.usage.total_tokens as u64 - self.usage.prompt_tokens as u64;
        Some(usage)
    }
}

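/// Drives the SSE stream returned by `v1/chat/completions` when `"stream": true` is set.
///
/// Text deltas are yielded as they arrive. Tool calls may be split across chunks, so
/// partial calls are accumulated per tool-call index and flushed once the stream ends,
/// while chunks that already carry complete arguments are yielded immediately. The last
/// reported usage is recorded on the tracing span and returned as the final response.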
async fn send_streaming_request<T>(
    client: T,
    req: http::Request<Vec<u8>>,
    span: tracing::Span,
) -> Result<
    crate::streaming::StreamingCompletionResponse<StreamingCompletionResponse>,
    CompletionError,
>
where
    T: HttpClientExt + Clone + 'static,
{
    let mut event_source = GenericEventSource::new(client, req);

    let stream = stream! {
        let span = tracing::Span::current();
        let mut final_usage = openai::Usage {
            prompt_tokens: 0,
            total_tokens: 0,
            prompt_tokens_details: None,
        };
        let mut text_response = String::new();
        let mut calls: HashMap<usize, (String, String, String)> = HashMap::new();

        while let Some(event_result) = event_source.next().await {
            match event_result {
                Ok(Event::Open) => {
                    tracing::trace!("SSE connection opened");
                    continue;
                }
                Ok(Event::Message(message)) => {
                    let data_str = message.data.trim();
                    if data_str.is_empty() || data_str == "[DONE]" {
                        continue;
                    }

                    let parsed = serde_json::from_str::<StreamingCompletionChunk>(data_str);
                    let Ok(data) = parsed else {
                        let err = parsed.unwrap_err();
                        tracing::debug!("Couldn't parse SSE payload: {:?}", err);
                        continue;
                    };

                    if let Some(choice) = data.choices.first() {
                        let delta = &choice.delta;

                        for tool_call in &delta.tool_calls {
                            let function = &tool_call.function;

                            // A named call without arguments starts a new partial tool call.
                            if function.name.as_ref().map(|s| !s.is_empty()).unwrap_or(false)
                                && empty_or_none(&function.arguments)
                            {
                                let id = tool_call.id.clone().unwrap_or_default();
                                let name = function.name.clone().unwrap();
                                calls.insert(tool_call.index, (id, name, String::new()));
                            }
                            // An unnamed chunk with arguments extends the partial call at this index.
                            else if function.name.as_ref().map(|s| s.is_empty()).unwrap_or(true)
                                && let Some(arguments) = &function.arguments
                                && !arguments.is_empty()
                            {
                                if let Some((id, name, existing_args)) = calls.get(&tool_call.index) {
                                    let combined = format!("{}{}", existing_args, arguments);
                                    calls.insert(tool_call.index, (id.clone(), name.clone(), combined));
                                }
                            }
                            // Otherwise the chunk carries a complete call: emit it immediately.
                            else {
                                let id = tool_call.id.clone().unwrap_or_default();
                                let name = function.name.clone().unwrap_or_default();
                                let arguments_str = function.arguments.clone().unwrap_or_default();

                                let Ok(arguments_json) = serde_json::from_str::<serde_json::Value>(&arguments_str) else {
                                    tracing::debug!("Couldn't parse tool call args '{}'", arguments_str);
                                    continue;
                                };

                                yield Ok(crate::streaming::RawStreamingChoice::ToolCall(
                                    crate::streaming::RawStreamingToolCall::new(id, name, arguments_json)
                                ));
                            }
                        }

                        if let Some(content) = &delta.content {
                            text_response += content;
                            yield Ok(crate::streaming::RawStreamingChoice::Message(content.clone()));
                        }
                    }

                    if let Some(usage) = data.usage {
                        final_usage = usage;
                    }
                }
                Err(crate::http_client::Error::StreamEnded) => break,
                Err(err) => {
                    tracing::error!(?err, "SSE error");
                    yield Err(CompletionError::ResponseError(err.to_string()));
                    break;
                }
            }
        }

        event_source.close();

        // Flush any tool calls that were accumulated across chunks.
        for (_, (id, name, arguments)) in calls {
            let Ok(arguments_json) = serde_json::from_str::<serde_json::Value>(&arguments) else {
                continue;
            };
            yield Ok(crate::streaming::RawStreamingChoice::ToolCall(
                crate::streaming::RawStreamingToolCall::new(id, name, arguments_json)
            ));
        }

        span.record("gen_ai.usage.input_tokens", final_usage.prompt_tokens);
        span.record("gen_ai.usage.output_tokens", final_usage.total_tokens - final_usage.prompt_tokens);

        yield Ok(crate::streaming::RawStreamingChoice::FinalResponse(
            StreamingCompletionResponse { usage: final_usage }
        ));
    }.instrument(span);

    Ok(crate::streaming::StreamingCompletionResponse::stream(
        Box::pin(stream),
    ))
}

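/// Embedding model backed by llamafile's OpenAI-compatible `v1/embeddings` endpoint.
///
/// llamafile does not report embedding dimensions, so `ndims` is supplied by the caller
/// (defaulting to 0 when constructed through `make`). Sketch (the `rig::...` path, model
/// name, and dimension below are placeholders):
///
/// ```rust,ignore
/// use rig::providers::llamafile;
///
/// let client = llamafile::Client::from_url("http://localhost:8080");
/// let embedder = llamafile::EmbeddingModel::new(client, "my-embedding-model", 768);
/// ```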
#[derive(Clone)]
pub struct EmbeddingModel<T = reqwest::Client> {
    client: Client<T>,
    pub model: String,
    ndims: usize,
}

impl<T> EmbeddingModel<T> {
    pub fn new(client: Client<T>, model: impl Into<String>, ndims: usize) -> Self {
        Self {
            client,
            model: model.into(),
            ndims,
        }
    }
}

impl<T> embeddings::EmbeddingModel for EmbeddingModel<T>
where
    T: HttpClientExt + Clone + std::fmt::Debug + Default + Send + 'static,
{
    const MAX_DOCUMENTS: usize = 1024;

    type Client = Client<T>;

    fn make(client: &Self::Client, model: impl Into<String>, ndims: Option<usize>) -> Self {
        Self::new(client.clone(), model, ndims.unwrap_or_default())
    }

    fn ndims(&self) -> usize {
        self.ndims
    }

    async fn embed_texts(
        &self,
        documents: impl IntoIterator<Item = String>,
    ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
        let documents = documents.into_iter().collect::<Vec<_>>();

        let body = serde_json::json!({
            "model": self.model,
            "input": documents,
        });

        let body = serde_json::to_vec(&body)?;

        let req = self
            .client
            .post("v1/embeddings")?
            .body(body)
            .map_err(|e| EmbeddingError::HttpError(e.into()))?;

        let response = self.client.send(req).await?;

        if response.status().is_success() {
            let body: Vec<u8> = response.into_body().await?;
            let body: ApiResponse<openai::EmbeddingResponse> = serde_json::from_slice(&body)?;

            match body {
                ApiResponse::Ok(response) => {
                    tracing::info!(target: "rig",
                        "Llamafile embedding token usage: {:?}",
                        response.usage
                    );

                    if response.data.len() != documents.len() {
                        return Err(EmbeddingError::ResponseError(
                            "Response data length does not match input length".into(),
                        ));
                    }

                    Ok(response
                        .data
                        .into_iter()
                        .zip(documents.into_iter())
                        .map(|(embedding, document)| embeddings::Embedding {
                            document,
                            vec: embedding
                                .embedding
                                .into_iter()
                                .filter_map(|n| n.as_f64())
                                .collect(),
                        })
                        .collect())
                }
                ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
            }
        } else {
            let text = http_client::text(response).await?;
            Err(EmbeddingError::ProviderError(text))
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::client::Nothing;

    #[test]
    fn test_client_initialization() {
        let _client =
            crate::providers::llamafile::Client::new(Nothing).expect("Client::new() failed");
        let _client_from_builder = crate::providers::llamafile::Client::builder()
            .api_key(Nothing)
            .build()
            .expect("Client::builder() failed");
    }

    #[test]
    fn test_client_from_url() {
        let _client = crate::providers::llamafile::Client::from_url("http://localhost:8080");
    }

    #[test]
    fn test_completion_request_conversion() {
        use crate::OneOrMany;
        use crate::completion::Message as CompletionMessage;
        use crate::message::{Text, UserContent};

        let completion_request = CompletionRequest {
            model: None,
            preamble: Some("You are a helpful assistant.".to_string()),
            chat_history: OneOrMany::one(CompletionMessage::User {
                content: OneOrMany::one(UserContent::Text(Text {
                    text: "Hello!".to_string(),
                })),
            }),
            documents: vec![],
            tools: vec![],
            temperature: Some(0.7),
            max_tokens: Some(256),
            tool_choice: None,
            additional_params: None,
            output_schema: None,
        };

        let request = LlamafileCompletionRequest::try_from((LLAMA_CPP, completion_request))
            .expect("Failed to create request");

        assert_eq!(request.model, LLAMA_CPP);
        // One system message from the preamble plus the single user message.
        assert_eq!(request.messages.len(), 2);
        assert_eq!(request.temperature, Some(0.7));
        assert_eq!(request.max_tokens, Some(256));
    }
}