1use crate::client::{
27 self, Capabilities, Capable, DebugExt, Nothing, Provider, ProviderBuilder, ProviderClient,
28};
29use crate::completion::GetTokenUsage;
30use crate::http_client::{self, HttpClientExt};
31use crate::providers::internal::openai_chat_completions_compatible::{
32 self, CompatibleChoiceData, CompatibleChunk, CompatibleFinishReason, CompatibleStreamProfile,
33};
34use crate::providers::openai::{self, StreamingToolCall};
35use crate::{
36 completion::{self, CompletionError, CompletionRequest},
37 embeddings::{self, EmbeddingError},
38 json_utils,
39};
40use bytes::Bytes;
41use serde::{Deserialize, Serialize};
42use serde_json::{Map, Value};
43use tracing::{Level, info_span};
44use tracing_futures::Instrument;
45
/// Default base URL for the llamafile server; [`LlamafileBuilder`] exposes it as its
/// `BASE_URL`.
const LLAMAFILE_API_BASE_URL: &str = "http://localhost:8080";
50
/// Model name constant with the value `"LLaMA_CPP"`.
///
/// The unit tests in this file pass it as the model name when building requests.
pub const LLAMA_CPP: &str = "LLaMA_CPP";
53
/// Zero-sized provider marker for llamafile.
///
/// It carries no state; the provider behaviour is attached through the `Provider`,
/// `Capabilities` and `DebugExt` impls in this file.
#[derive(Clone, Copy, Debug, Default)]
pub struct LlamafileExt;
56
/// Zero-sized builder marker whose `ProviderBuilder` impl produces a [`LlamafileExt`].
#[derive(Clone, Copy, Debug, Default)]
pub struct LlamafileBuilder;
59
60impl Provider for LlamafileExt {
61 type Builder = LlamafileBuilder;
62 const VERIFY_PATH: &'static str = "v1/models";
63}
64
65impl<H> Capabilities<H> for LlamafileExt {
66 type Completion = Capable<CompletionModel<H>>;
67 type Embeddings = Capable<EmbeddingModel<H>>;
68 type Transcription = Nothing;
69 type ModelListing = Nothing;
70 #[cfg(feature = "image")]
71 type ImageGeneration = Nothing;
72 #[cfg(feature = "audio")]
73 type AudioGeneration = Nothing;
74 type Rerank = Nothing;
75}
76
77impl DebugExt for LlamafileExt {}
78
79impl ProviderBuilder for LlamafileBuilder {
80 type Extension<H>
81 = LlamafileExt
82 where
83 H: HttpClientExt;
84 type ApiKey = Nothing;
85
86 const BASE_URL: &'static str = LLAMAFILE_API_BASE_URL;
87
88 fn build<H>(
89 _builder: &client::ClientBuilder<Self, Self::ApiKey, H>,
90 ) -> http_client::Result<Self::Extension<H>>
91 where
92 H: HttpClientExt,
93 {
94 Ok(LlamafileExt)
95 }
96}
97
98pub type Client<H = reqwest::Client> = client::Client<LlamafileExt, H>;
99pub type ClientBuilder<H = crate::markers::Missing> =
100 client::ClientBuilder<LlamafileBuilder, Nothing, H>;
101
102impl Client {
103 pub fn from_url(base_url: &str) -> crate::client::ProviderClientResult<Self> {
106 Self::builder()
107 .api_key(Nothing)
108 .base_url(base_url)
109 .build()
110 .map_err(Into::into)
111 }
112}
113
114impl ProviderClient for Client {
115 type Input = Nothing;
116 type Error = crate::client::ProviderClientError;
117
118 fn from_env() -> Result<Self, Self::Error> {
119 let api_base = crate::client::required_env_var("LLAMAFILE_API_BASE_URL")?;
120 Self::from_url(&api_base)
121 }
122
123 fn from_val(_: Self::Input) -> Result<Self, Self::Error> {
124 Self::builder().api_key(Nothing).build().map_err(Into::into)
125 }
126}
127
128#[derive(Debug, Deserialize)]
133struct ApiErrorResponse {
134 message: String,
135}
136
137#[derive(Debug, Deserialize)]
138#[serde(untagged)]
139enum ApiResponse<T> {
140 Ok(T),
141 Err(ApiErrorResponse),
142}
143
144#[derive(Debug, Serialize)]
151struct LlamafileCompletionRequest {
152 model: String,
153 messages: Vec<Value>,
154 #[serde(skip_serializing_if = "Option::is_none")]
155 temperature: Option<f64>,
156 #[serde(skip_serializing_if = "Option::is_none")]
157 max_tokens: Option<u64>,
158 #[serde(skip_serializing_if = "Vec::is_empty")]
159 tools: Vec<openai::ToolDefinition>,
160 #[serde(flatten, skip_serializing_if = "Option::is_none")]
161 additional_params: Option<serde_json::Value>,
162}
163
/// Joins the non-empty segments with a blank line (`"\n\n"`) between them.
///
/// Empty segments are skipped entirely, so they never produce stray separators. Returns
/// an empty string when there are no non-empty segments.
fn join_text_segments<I>(segments: I) -> String
where
    I: IntoIterator<Item = String>,
{
    let mut joined = String::new();

    for segment in segments {
        if segment.is_empty() {
            continue;
        }
        // `joined` is non-empty exactly when an earlier segment was already appended.
        if !joined.is_empty() {
            joined.push_str("\n\n");
        }
        joined.push_str(&segment);
    }

    joined
}
179
180fn flatten_system_content(content: &crate::OneOrMany<openai::SystemContent>) -> String {
181 join_text_segments(content.iter().map(|item| item.text.clone()))
182}
183
184fn flatten_user_content(content: &crate::OneOrMany<openai::UserContent>) -> Option<String> {
185 content
186 .iter()
187 .map(|item| match item {
188 openai::UserContent::Text { text } => Some(text.clone()),
189 _ => None,
190 })
191 .collect::<Option<Vec<_>>>()
192 .map(join_text_segments)
193}
194
195fn flatten_assistant_content(content: &[openai::AssistantContent]) -> String {
196 join_text_segments(content.iter().map(|item| match item {
197 openai::AssistantContent::Text { text } => text.clone(),
198 openai::AssistantContent::Refusal { refusal } => refusal.clone(),
199 }))
200}
201
202fn optional_value<T>(value: Option<T>) -> Result<Option<Value>, CompletionError>
203where
204 T: Serialize,
205{
206 value
207 .map(serde_json::to_value)
208 .transpose()
209 .map_err(Into::into)
210}
211
212fn message_content_value<T>(
213 flattened: Option<String>,
214 original: &T,
215) -> Result<Value, CompletionError>
216where
217 T: Serialize,
218{
219 match flattened {
220 Some(text) => Ok(Value::String(text)),
221 None => Ok(serde_json::to_value(original)?),
222 }
223}
224
225fn llamafile_message_value(message: openai::Message) -> Result<Value, CompletionError> {
226 match message {
227 openai::Message::System { content, name } => {
228 let mut object = Map::new();
229 object.insert("role".into(), Value::String("system".into()));
230 object.insert(
231 "content".into(),
232 Value::String(flatten_system_content(&content)),
233 );
234 if let Some(name) = name {
235 object.insert("name".into(), Value::String(name));
236 }
237 Ok(Value::Object(object))
238 }
239 openai::Message::User { content, name } => {
240 let mut object = Map::new();
241 object.insert("role".into(), Value::String("user".into()));
242 object.insert(
243 "content".into(),
244 message_content_value(flatten_user_content(&content), &content)?,
245 );
246 if let Some(name) = name {
247 object.insert("name".into(), Value::String(name));
248 }
249 Ok(Value::Object(object))
250 }
251 openai::Message::Assistant {
252 content,
253 refusal,
254 reasoning: _,
255 audio,
256 name,
257 tool_calls,
258 } => {
259 let mut object = Map::new();
260 object.insert("role".into(), Value::String("assistant".into()));
261 object.insert(
262 "content".into(),
263 Value::String(flatten_assistant_content(&content)),
264 );
265 if let Some(refusal) = refusal {
266 object.insert("refusal".into(), Value::String(refusal));
267 }
268 if let Some(audio) = optional_value(audio)? {
269 object.insert("audio".into(), audio);
270 }
271 if let Some(name) = name {
272 object.insert("name".into(), Value::String(name));
273 }
274 if !tool_calls.is_empty() {
275 object.insert("tool_calls".into(), serde_json::to_value(tool_calls)?);
276 }
277 Ok(Value::Object(object))
278 }
279 openai::Message::ToolResult {
280 tool_call_id,
281 content,
282 } => {
283 let mut object = Map::new();
284 object.insert("role".into(), Value::String("tool".into()));
285 object.insert("tool_call_id".into(), Value::String(tool_call_id));
286 object.insert("content".into(), Value::String(content.as_text()));
287 Ok(Value::Object(object))
288 }
289 }
290}
291
292impl TryFrom<(&str, CompletionRequest)> for LlamafileCompletionRequest {
293 type Error = CompletionError;
294
295 fn try_from((model, req): (&str, CompletionRequest)) -> Result<Self, Self::Error> {
296 let chat_history = req.chat_history_with_documents();
297 if req.output_schema.is_some() {
298 tracing::warn!("Structured outputs may not be supported by llamafile");
299 }
300 let model = req.model.clone().unwrap_or_else(|| model.to_string());
301
302 let mut full_history: Vec<openai::Message> = match &req.preamble {
304 Some(preamble) => vec![openai::Message::system(preamble)],
305 None => vec![],
306 };
307
308 let chat_history: Vec<openai::Message> = chat_history
309 .into_iter()
310 .map(|msg| msg.try_into())
311 .collect::<Result<Vec<Vec<openai::Message>>, _>>()?
312 .into_iter()
313 .flatten()
314 .collect();
315
316 full_history.extend(chat_history);
317
318 Ok(Self {
319 model,
320 messages: full_history
321 .into_iter()
322 .map(llamafile_message_value)
323 .collect::<Result<Vec<_>, _>>()?,
324 temperature: req.temperature,
325 max_tokens: req.max_tokens,
326 tools: req
327 .tools
328 .into_iter()
329 .map(openai::ToolDefinition::from)
330 .collect(),
331 additional_params: req.additional_params,
332 })
333 }
334}
335
336#[derive(Clone)]
342pub struct CompletionModel<T = reqwest::Client> {
343 client: Client<T>,
344 pub model: String,
346}
347
348impl<T> CompletionModel<T> {
349 pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
351 Self {
352 client,
353 model: model.into(),
354 }
355 }
356}
357
358impl<T> completion::CompletionModel for CompletionModel<T>
359where
360 T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
361{
362 type Response = openai::CompletionResponse;
363 type StreamingResponse = StreamingCompletionResponse;
364 type Client = Client<T>;
365
366 fn make(client: &Self::Client, model: impl Into<String>) -> Self {
367 Self::new(client.clone(), model)
368 }
369
370 async fn completion(
371 &self,
372 completion_request: CompletionRequest,
373 ) -> Result<completion::CompletionResponse<openai::CompletionResponse>, CompletionError> {
374 let span = if tracing::Span::current().is_disabled() {
375 info_span!(
376 target: "rig::completions",
377 "chat",
378 gen_ai.operation.name = "chat",
379 gen_ai.provider.name = "llamafile",
380 gen_ai.request.model = self.model,
381 gen_ai.system_instructions = completion_request.preamble,
382 gen_ai.response.id = tracing::field::Empty,
383 gen_ai.response.model = tracing::field::Empty,
384 gen_ai.usage.output_tokens = tracing::field::Empty,
385 gen_ai.usage.input_tokens = tracing::field::Empty,
386 gen_ai.usage.cache_read.input_tokens = tracing::field::Empty,
387 )
388 } else {
389 tracing::Span::current()
390 };
391
392 let request =
393 LlamafileCompletionRequest::try_from((self.model.as_ref(), completion_request))?;
394
395 if tracing::enabled!(Level::TRACE) {
396 tracing::trace!(target: "rig::completions",
397 "Llamafile completion request: {}",
398 serde_json::to_string_pretty(&request)?
399 );
400 }
401
402 let body = serde_json::to_vec(&request)?;
403 let req = self
404 .client
405 .post("v1/chat/completions")?
406 .body(body)
407 .map_err(|e| CompletionError::HttpError(e.into()))?;
408
409 async move {
410 let response = self.client.send::<_, Bytes>(req).await?;
411 let status = response.status();
412 let response_body = response.into_body().into_future().await?.to_vec();
413
414 if status.is_success() {
415 match serde_json::from_slice::<ApiResponse<openai::CompletionResponse>>(
416 &response_body,
417 )? {
418 ApiResponse::Ok(response) => {
419 let span = tracing::Span::current();
420 span.record("gen_ai.response.id", response.id.clone());
421 span.record("gen_ai.response.model", response.model.clone());
422 if let Some(ref usage) = response.usage {
423 span.record("gen_ai.usage.input_tokens", usage.prompt_tokens);
424 span.record(
425 "gen_ai.usage.output_tokens",
426 usage.total_tokens - usage.prompt_tokens,
427 );
428 }
429
430 if tracing::enabled!(Level::TRACE) {
431 tracing::trace!(target: "rig::completions",
432 "Llamafile completion response: {}",
433 serde_json::to_string_pretty(&response)?
434 );
435 }
436
437 response.try_into()
438 }
439 ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
440 }
441 } else {
442 Err(CompletionError::ProviderError(
443 String::from_utf8_lossy(&response_body).to_string(),
444 ))
445 }
446 }
447 .instrument(span)
448 .await
449 }
450
451 async fn stream(
452 &self,
453 completion_request: CompletionRequest,
454 ) -> Result<
455 crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
456 CompletionError,
457 > {
458 let span = if tracing::Span::current().is_disabled() {
459 info_span!(
460 target: "rig::completions",
461 "chat_streaming",
462 gen_ai.operation.name = "chat_streaming",
463 gen_ai.provider.name = "llamafile",
464 gen_ai.request.model = self.model,
465 gen_ai.system_instructions = completion_request.preamble,
466 gen_ai.response.id = tracing::field::Empty,
467 gen_ai.response.model = tracing::field::Empty,
468 gen_ai.usage.output_tokens = tracing::field::Empty,
469 gen_ai.usage.input_tokens = tracing::field::Empty,
470 )
471 } else {
472 tracing::Span::current()
473 };
474
475 let mut request =
476 LlamafileCompletionRequest::try_from((self.model.as_ref(), completion_request))?;
477
478 let params = json_utils::merge(
479 request.additional_params.unwrap_or(serde_json::json!({})),
480 serde_json::json!({"stream": true}),
481 );
482 request.additional_params = Some(params);
483
484 if tracing::enabled!(Level::TRACE) {
485 tracing::trace!(target: "rig::completions",
486 "Llamafile streaming completion request: {}",
487 serde_json::to_string_pretty(&request)?
488 );
489 }
490
491 let body = serde_json::to_vec(&request)?;
492 let req = self
493 .client
494 .post("v1/chat/completions")?
495 .body(body)
496 .map_err(|e| CompletionError::HttpError(e.into()))?;
497
498 send_streaming_request(self.client.clone(), req, span).await
499 }
500}
501
502#[derive(Deserialize, Debug)]
507struct StreamingDelta {
508 #[serde(default)]
509 content: Option<String>,
510 #[serde(default, deserialize_with = "json_utils::null_or_vec")]
511 tool_calls: Vec<StreamingToolCall>,
512}
513
514#[derive(Deserialize, Debug)]
515struct StreamingChoice {
516 delta: StreamingDelta,
517 #[serde(default)]
518 finish_reason: Option<openai::completion::streaming::FinishReason>,
519}
520
521#[derive(Deserialize, Debug)]
522struct StreamingCompletionChunk {
523 id: Option<String>,
524 model: Option<String>,
525 choices: Vec<StreamingChoice>,
526 usage: Option<openai::Usage>,
527}
528
529#[derive(Clone, Deserialize, Serialize, Debug)]
531pub struct StreamingCompletionResponse {
532 pub usage: openai::Usage,
534}
535
536impl GetTokenUsage for StreamingCompletionResponse {
537 fn token_usage(&self) -> crate::completion::Usage {
538 self.usage.token_usage()
539 }
540}
541
/// Zero-sized stream profile handed to the shared OpenAI-compatible streaming helper.
#[derive(Copy, Clone)]
struct LlamafileCompatibleProfile;
544
545impl CompatibleStreamProfile for LlamafileCompatibleProfile {
546 type Usage = openai::Usage;
547 type Detail = ();
548 type FinalResponse = StreamingCompletionResponse;
549
550 fn normalize_chunk(
551 &self,
552 data: &str,
553 ) -> Result<Option<CompatibleChunk<Self::Usage, Self::Detail>>, CompletionError> {
554 let data = match serde_json::from_str::<StreamingCompletionChunk>(data) {
555 Ok(data) => data,
556 Err(error) => {
557 tracing::debug!(
558 ?error,
559 "Couldn't parse SSE payload as StreamingCompletionChunk"
560 );
561 return Ok(None);
562 }
563 };
564
565 Ok(Some(
566 openai_chat_completions_compatible::normalize_first_choice_chunk(
567 data.id,
568 data.model,
569 data.usage,
570 &data.choices,
571 |choice| CompatibleChoiceData {
572 finish_reason: if choice.finish_reason
573 == Some(openai::completion::streaming::FinishReason::ToolCalls)
574 {
575 CompatibleFinishReason::ToolCalls
576 } else {
577 CompatibleFinishReason::Other
578 },
579 text: choice.delta.content.clone(),
580 reasoning: None,
581 tool_calls: openai_chat_completions_compatible::tool_call_chunks(
582 &choice.delta.tool_calls,
583 ),
584 details: Vec::new(),
585 },
586 ),
587 ))
588 }
589
590 fn build_final_response(&self, usage: Self::Usage) -> Self::FinalResponse {
591 StreamingCompletionResponse { usage }
592 }
593
594 fn uses_distinct_tool_call_eviction(&self) -> bool {
595 true
596 }
597
598 fn emits_complete_single_chunk_tool_calls(&self) -> bool {
599 true
600 }
601}
602
603async fn send_streaming_request<T>(
604 client: T,
605 req: http::Request<Vec<u8>>,
606 span: tracing::Span,
607) -> Result<
608 crate::streaming::StreamingCompletionResponse<StreamingCompletionResponse>,
609 CompletionError,
610>
611where
612 T: HttpClientExt + Clone + 'static,
613{
614 tracing::Instrument::instrument(
615 openai_chat_completions_compatible::send_compatible_streaming_request(
616 client,
617 req,
618 LlamafileCompatibleProfile,
619 ),
620 span,
621 )
622 .await
623}
624
625#[derive(Clone)]
633pub struct EmbeddingModel<T = reqwest::Client> {
634 client: Client<T>,
635 pub model: String,
637 ndims: usize,
638}
639
640impl<T> EmbeddingModel<T> {
641 pub fn new(client: Client<T>, model: impl Into<String>, ndims: usize) -> Self {
643 Self {
644 client,
645 model: model.into(),
646 ndims,
647 }
648 }
649}
650
651impl<T> embeddings::EmbeddingModel for EmbeddingModel<T>
652where
653 T: HttpClientExt + Clone + std::fmt::Debug + Default + Send + 'static,
654{
655 const MAX_DOCUMENTS: usize = 1024;
656
657 type Client = Client<T>;
658
659 fn make(client: &Self::Client, model: impl Into<String>, ndims: Option<usize>) -> Self {
660 Self::new(client.clone(), model, ndims.unwrap_or_default())
661 }
662
663 fn ndims(&self) -> usize {
664 self.ndims
665 }
666
667 async fn embed_texts(
668 &self,
669 documents: impl IntoIterator<Item = String>,
670 ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
671 let documents = documents.into_iter().collect::<Vec<_>>();
672
673 let body = serde_json::json!({
674 "model": self.model,
675 "input": documents,
676 });
677
678 let body = serde_json::to_vec(&body)?;
679
680 let req = self
681 .client
682 .post("v1/embeddings")?
683 .body(body)
684 .map_err(|e| EmbeddingError::HttpError(e.into()))?;
685
686 let response = self.client.send(req).await?;
687
688 if response.status().is_success() {
689 let body: Vec<u8> = response.into_body().await?;
690 let body: ApiResponse<openai::EmbeddingResponse> = serde_json::from_slice(&body)?;
691
692 match body {
693 ApiResponse::Ok(response) => {
694 tracing::info!(target: "rig",
695 "Llamafile embedding token usage: {:?}",
696 response.usage
697 );
698
699 if response.data.len() != documents.len() {
700 return Err(EmbeddingError::ResponseError(
701 "Response data length does not match input length".into(),
702 ));
703 }
704
705 Ok(response
706 .data
707 .into_iter()
708 .zip(documents.into_iter())
709 .map(|(embedding, document)| embeddings::Embedding {
710 document,
711 vec: embedding
712 .embedding
713 .into_iter()
714 .filter_map(|n| n.as_f64())
715 .collect(),
716 })
717 .collect())
718 }
719 ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
720 }
721 } else {
722 let text = http_client::text(response).await?;
723 Err(EmbeddingError::ProviderError(text))
724 }
725 }
726}
727
/// Unit tests for client construction and request conversion. None of them contact a
/// server.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::client::Nothing;
    use crate::completion::Document;
    use std::collections::HashMap;

    /// Both construction paths (`new` and the builder) succeed with default settings.
    #[test]
    fn test_client_initialization() {
        let _client =
            crate::providers::llamafile::Client::new(Nothing).expect("Client::new() failed");
        let _client_from_builder = crate::providers::llamafile::Client::builder()
            .api_key(Nothing)
            .build()
            .expect("Client::builder() failed");
    }

    /// `from_url` succeeds for a well-formed URL. The result is unwrapped so that a
    /// construction failure actually fails the test instead of being silently discarded.
    #[test]
    fn test_client_from_url() {
        let _client = crate::providers::llamafile::Client::from_url("http://localhost:8080")
            .expect("Client::from_url() failed");
    }

    /// The preamble becomes a leading system message and scalar options are carried over.
    #[test]
    fn test_completion_request_conversion() {
        use crate::OneOrMany;
        use crate::completion::Message as CompletionMessage;
        use crate::message::{Text, UserContent};

        let completion_request = CompletionRequest {
            model: None,
            preamble: Some("You are a helpful assistant.".to_string()),
            chat_history: OneOrMany::one(CompletionMessage::User {
                content: OneOrMany::one(UserContent::Text(Text::new("Hello!".to_string()))),
            }),
            documents: vec![],
            tools: vec![],
            temperature: Some(0.7),
            max_tokens: Some(256),
            tool_choice: None,
            additional_params: None,
            output_schema: None,
        };

        let request = LlamafileCompletionRequest::try_from((LLAMA_CPP, completion_request))
            .expect("Failed to create request");

        assert_eq!(request.model, LLAMA_CPP);
        // System message from the preamble, then the user message.
        assert_eq!(request.messages.len(), 2);
        assert_eq!(
            request.messages[0]["content"],
            "You are a helpful assistant."
        );
        assert_eq!(request.messages[1]["content"], "Hello!");
        assert_eq!(request.temperature, Some(0.7));
        assert_eq!(request.max_tokens, Some(256));
    }

    /// Documents attached to a request are rendered into a single string-valued message.
    #[test]
    fn test_completion_request_flattens_text_only_document_arrays() {
        use crate::completion::CompletionRequestBuilder;
        use crate::test_utils::MockCompletionModel;

        let completion_request = CompletionRequestBuilder::new(
            MockCompletionModel::default(),
            "What does glarb-glarb mean?",
        )
        .document(Document {
            id: "doc-1".into(),
            text: "Definition of flurbo: a green alien.".into(),
            additional_props: HashMap::new(),
        })
        .document(Document {
            id: "doc-2".into(),
            text: "Definition of glarb-glarb: an ancient farming tool.".into(),
            additional_props: HashMap::new(),
        })
        .build();

        let request = LlamafileCompletionRequest::try_from((LLAMA_CPP, completion_request))
            .expect("Failed to create request");

        assert_eq!(request.messages.len(), 2);
        assert!(request.messages[0]["content"].is_string());
        let documents = request.messages[0]["content"]
            .as_str()
            .expect("documents should serialize as a string");
        assert!(documents.contains("Definition of flurbo"));
        assert!(documents.contains("Definition of glarb-glarb"));
    }

    /// Assistant text content is flattened to a string while tool calls stay an array.
    #[test]
    fn test_llamafile_message_value_flattens_assistant_text_content() {
        let message = openai::Message::Assistant {
            content: vec![openai::AssistantContent::Text {
                text: "Tool returned the answer.".into(),
            }],
            reasoning: None,
            refusal: None,
            audio: None,
            name: None,
            tool_calls: vec![openai::ToolCall {
                id: "call_1".into(),
                r#type: openai::ToolType::Function,
                function: openai::Function {
                    name: "weather".into(),
                    arguments: serde_json::json!({"city": "London"}),
                },
            }],
        };

        let value = llamafile_message_value(message).expect("message conversion should succeed");

        assert_eq!(value["role"], "assistant");
        assert_eq!(value["content"], "Tool returned the answer.");
        assert!(value["tool_calls"].is_array());
    }
}