// rig/providers/llamafile.rs
1//! Llamafile API client and Rig integration
2//!
3//! [Llamafile](https://github.com/Mozilla-Ocho/llamafile) is a Mozilla Builders project
4//! that distributes LLMs as single-file executables. When started, it exposes an
5//! OpenAI-compatible API at `http://localhost:8080/v1`.
6//!
7//! # Example
8//! ```rust,ignore
9//! use rig::providers::llamafile;
10//! use rig::completion::Prompt;
11//!
12//! // Create a new Llamafile client (defaults to http://localhost:8080)
13//! let client = llamafile::Client::from_url("http://localhost:8080");
14//!
15//! // Create an agent with a preamble
16//! let agent = client
17//!     .agent(llamafile::LLAMA_CPP)
18//!     .preamble("You are a helpful assistant.")
19//!     .build();
20//!
21//! // Prompt the agent and print the response
22//! let response = agent.prompt("Hello!").await?;
23//! println!("{response}");
24//! ```
25
26use crate::client::{
27    self, Capabilities, Capable, DebugExt, Nothing, Provider, ProviderBuilder, ProviderClient,
28};
29use crate::completion::GetTokenUsage;
30use crate::http_client::sse::{Event, GenericEventSource};
31use crate::http_client::{self, HttpClientExt};
32use crate::json_utils::empty_or_none;
33use crate::providers::openai::{self, StreamingToolCall};
34use crate::{
35    completion::{self, CompletionError, CompletionRequest},
36    embeddings::{self, EmbeddingError},
37    json_utils,
38};
39use async_stream::stream;
40use bytes::Bytes;
41use futures::StreamExt;
42use serde::{Deserialize, Serialize};
43use std::collections::HashMap;
44use tracing::{Level, info_span};
45use tracing_futures::Instrument;
46
47// ================================================================
48// Main Llamafile Client
49// ================================================================
/// Default base URL of a locally running llamafile server
/// (llamafile exposes an OpenAI-compatible API on port 8080 by default).
const LLAMAFILE_API_BASE_URL: &str = "http://localhost:8080";

/// The default model identifier reported by llamafile.
pub const LLAMA_CPP: &str = "LLaMA_CPP";
54
/// Marker type carrying the llamafile provider's capability set.
#[derive(Debug, Default, Clone, Copy)]
pub struct LlamafileExt;

/// Marker type used to build a llamafile [`Client`].
#[derive(Debug, Default, Clone, Copy)]
pub struct LlamafileBuilder;

impl Provider for LlamafileExt {
    type Builder = LlamafileBuilder;
    // Endpoint probed to verify that the server is reachable.
    const VERIFY_PATH: &'static str = "v1/models";
}

// Llamafile only exposes completion and embedding endpoints; all other
// capabilities (transcription, model listing, image/audio) are `Nothing`.
impl<H> Capabilities<H> for LlamafileExt {
    type Completion = Capable<CompletionModel<H>>;
    type Embeddings = Capable<EmbeddingModel<H>>;
    type Transcription = Nothing;
    type ModelListing = Nothing;
    #[cfg(feature = "image")]
    type ImageGeneration = Nothing;
    #[cfg(feature = "audio")]
    type AudioGeneration = Nothing;
}

impl DebugExt for LlamafileExt {}
78
impl ProviderBuilder for LlamafileBuilder {
    type Extension<H>
        = LlamafileExt
    where
        H: HttpClientExt;
    // Llamafile runs locally and requires no API key.
    type ApiKey = Nothing;

    const BASE_URL: &'static str = LLAMAFILE_API_BASE_URL;

    /// Build the (stateless) llamafile extension; this cannot fail.
    fn build<H>(
        _builder: &client::ClientBuilder<Self, Self::ApiKey, H>,
    ) -> http_client::Result<Self::Extension<H>>
    where
        H: HttpClientExt,
    {
        Ok(LlamafileExt)
    }
}
97
/// Llamafile client over a generic HTTP transport (defaults to `reqwest`).
pub type Client<H = reqwest::Client> = client::Client<LlamafileExt, H>;
/// Builder for [`Client`]; llamafile needs no API key, hence `Nothing`.
pub type ClientBuilder<H = reqwest::Client> = client::ClientBuilder<LlamafileBuilder, Nothing, H>;
100
impl Client {
    /// Create a client pointing at the given llamafile base URL
    /// (e.g. `http://localhost:8080`).
    ///
    /// # Panics
    /// Panics if the underlying client builder fails (e.g. an invalid URL).
    pub fn from_url(base_url: &str) -> Self {
        Self::builder()
            .api_key(Nothing)
            .base_url(base_url)
            .build()
            .expect("Failed to build llamafile client")
    }
}
112
113impl ProviderClient for Client {
114    type Input = Nothing;
115
116    fn from_env() -> Self {
117        let api_base =
118            std::env::var("LLAMAFILE_API_BASE_URL").expect("LLAMAFILE_API_BASE_URL not set");
119        Self::from_url(&api_base)
120    }
121
122    fn from_val(_: Self::Input) -> Self {
123        Self::builder().api_key(Nothing).build().unwrap()
124    }
125}
126
127// ================================================================
128// API Error Handling
129// ================================================================
130
/// Error payload returned by the llamafile server.
#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    message: String,
}

/// Untagged wrapper: a response body either deserializes as the expected
/// payload `T` or as an [`ApiErrorResponse`].
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ApiResponse<T> {
    Ok(T),
    Err(ApiErrorResponse),
}
142
143// ================================================================
144// Completion Request
145// ================================================================
146
/// Llamafile uses the OpenAI chat completions format.
/// We reuse the OpenAI `Message` type for maximum compatibility.
#[derive(Debug, Serialize, Deserialize)]
struct LlamafileCompletionRequest {
    /// Model identifier sent to the server (usually [`LLAMA_CPP`]).
    model: String,
    /// Full message history: system preamble, documents, then chat turns.
    messages: Vec<openai::Message>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    max_tokens: Option<u64>,
    #[serde(skip_serializing_if = "Vec::is_empty")]
    tools: Vec<openai::ToolDefinition>,
    /// Extra provider-specific parameters, flattened into the JSON body.
    #[serde(flatten, skip_serializing_if = "Option::is_none")]
    additional_params: Option<serde_json::Value>,
}
162
163impl TryFrom<(&str, CompletionRequest)> for LlamafileCompletionRequest {
164    type Error = CompletionError;
165
166    fn try_from((model, req): (&str, CompletionRequest)) -> Result<Self, Self::Error> {
167        if req.output_schema.is_some() {
168            tracing::warn!("Structured outputs may not be supported by llamafile");
169        }
170        let model = req.model.clone().unwrap_or_else(|| model.to_string());
171
172        // Build message history: preamble -> documents -> chat history
173        let mut full_history: Vec<openai::Message> = match &req.preamble {
174            Some(preamble) => vec![openai::Message::system(preamble)],
175            None => vec![],
176        };
177
178        if let Some(docs) = req.normalized_documents() {
179            let docs: Vec<openai::Message> = docs.try_into()?;
180            full_history.extend(docs);
181        }
182
183        let chat_history: Vec<openai::Message> = req
184            .chat_history
185            .clone()
186            .into_iter()
187            .map(|msg| msg.try_into())
188            .collect::<Result<Vec<Vec<openai::Message>>, _>>()?
189            .into_iter()
190            .flatten()
191            .collect();
192
193        full_history.extend(chat_history);
194
195        Ok(Self {
196            model,
197            messages: full_history,
198            temperature: req.temperature,
199            max_tokens: req.max_tokens,
200            tools: req
201                .tools
202                .into_iter()
203                .map(openai::ToolDefinition::from)
204                .collect(),
205            additional_params: req.additional_params,
206        })
207    }
208}
209
210// ================================================================
211// Completion Model
212// ================================================================
213
/// Llamafile completion model.
#[derive(Clone)]
pub struct CompletionModel<T = reqwest::Client> {
    // Handle to the llamafile HTTP client.
    client: Client<T>,
    /// The model identifier (usually `LLaMA_CPP`).
    pub model: String,
}

impl<T> CompletionModel<T> {
    /// Create a new completion model for the given client and model name.
    pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
        Self {
            client,
            model: model.into(),
        }
    }
}
231
232impl<T> completion::CompletionModel for CompletionModel<T>
233where
234    T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
235{
236    type Response = openai::CompletionResponse;
237    type StreamingResponse = StreamingCompletionResponse;
238    type Client = Client<T>;
239
240    fn make(client: &Self::Client, model: impl Into<String>) -> Self {
241        Self::new(client.clone(), model)
242    }
243
244    async fn completion(
245        &self,
246        completion_request: CompletionRequest,
247    ) -> Result<completion::CompletionResponse<openai::CompletionResponse>, CompletionError> {
248        let span = if tracing::Span::current().is_disabled() {
249            info_span!(
250                target: "rig::completions",
251                "chat",
252                gen_ai.operation.name = "chat",
253                gen_ai.provider.name = "llamafile",
254                gen_ai.request.model = self.model,
255                gen_ai.system_instructions = completion_request.preamble,
256                gen_ai.response.id = tracing::field::Empty,
257                gen_ai.response.model = tracing::field::Empty,
258                gen_ai.usage.output_tokens = tracing::field::Empty,
259                gen_ai.usage.input_tokens = tracing::field::Empty,
260            )
261        } else {
262            tracing::Span::current()
263        };
264
265        let request =
266            LlamafileCompletionRequest::try_from((self.model.as_ref(), completion_request))?;
267
268        if tracing::enabled!(Level::TRACE) {
269            tracing::trace!(target: "rig::completions",
270                "Llamafile completion request: {}",
271                serde_json::to_string_pretty(&request)?
272            );
273        }
274
275        let body = serde_json::to_vec(&request)?;
276        let req = self
277            .client
278            .post("v1/chat/completions")?
279            .body(body)
280            .map_err(|e| CompletionError::HttpError(e.into()))?;
281
282        async move {
283            let response = self.client.send::<_, Bytes>(req).await?;
284            let status = response.status();
285            let response_body = response.into_body().into_future().await?.to_vec();
286
287            if status.is_success() {
288                match serde_json::from_slice::<ApiResponse<openai::CompletionResponse>>(
289                    &response_body,
290                )? {
291                    ApiResponse::Ok(response) => {
292                        let span = tracing::Span::current();
293                        span.record("gen_ai.response.id", response.id.clone());
294                        span.record("gen_ai.response.model_name", response.model.clone());
295                        if let Some(ref usage) = response.usage {
296                            span.record("gen_ai.usage.input_tokens", usage.prompt_tokens);
297                            span.record(
298                                "gen_ai.usage.output_tokens",
299                                usage.total_tokens - usage.prompt_tokens,
300                            );
301                        }
302
303                        if tracing::enabled!(Level::TRACE) {
304                            tracing::trace!(target: "rig::completions",
305                                "Llamafile completion response: {}",
306                                serde_json::to_string_pretty(&response)?
307                            );
308                        }
309
310                        response.try_into()
311                    }
312                    ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
313                }
314            } else {
315                Err(CompletionError::ProviderError(
316                    String::from_utf8_lossy(&response_body).to_string(),
317                ))
318            }
319        }
320        .instrument(span)
321        .await
322    }
323
324    async fn stream(
325        &self,
326        completion_request: CompletionRequest,
327    ) -> Result<
328        crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
329        CompletionError,
330    > {
331        let span = if tracing::Span::current().is_disabled() {
332            info_span!(
333                target: "rig::completions",
334                "chat_streaming",
335                gen_ai.operation.name = "chat_streaming",
336                gen_ai.provider.name = "llamafile",
337                gen_ai.request.model = self.model,
338                gen_ai.system_instructions = completion_request.preamble,
339                gen_ai.response.id = tracing::field::Empty,
340                gen_ai.response.model = tracing::field::Empty,
341                gen_ai.usage.output_tokens = tracing::field::Empty,
342                gen_ai.usage.input_tokens = tracing::field::Empty,
343            )
344        } else {
345            tracing::Span::current()
346        };
347
348        let mut request =
349            LlamafileCompletionRequest::try_from((self.model.as_ref(), completion_request))?;
350
351        let params = json_utils::merge(
352            request.additional_params.unwrap_or(serde_json::json!({})),
353            serde_json::json!({"stream": true}),
354        );
355        request.additional_params = Some(params);
356
357        if tracing::enabled!(Level::TRACE) {
358            tracing::trace!(target: "rig::completions",
359                "Llamafile streaming completion request: {}",
360                serde_json::to_string_pretty(&request)?
361            );
362        }
363
364        let body = serde_json::to_vec(&request)?;
365        let req = self
366            .client
367            .post("v1/chat/completions")?
368            .body(body)
369            .map_err(|e| CompletionError::HttpError(e.into()))?;
370
371        send_streaming_request(self.client.clone(), req, span).await
372    }
373}
374
375// ================================================================
376// Streaming Support
377// ================================================================
378
/// Incremental message delta inside one streaming chunk.
#[derive(Deserialize, Debug)]
struct StreamingDelta {
    /// Text content emitted in this chunk, if any.
    #[serde(default)]
    content: Option<String>,
    /// Tool-call fragments; JSON `null` deserializes to an empty vec.
    #[serde(default, deserialize_with = "json_utils::null_or_vec")]
    tool_calls: Vec<StreamingToolCall>,
}

/// A single choice within a streaming chunk.
#[derive(Deserialize, Debug)]
struct StreamingChoice {
    delta: StreamingDelta,
}

/// One SSE payload from the chat completions stream.
#[derive(Deserialize, Debug)]
struct StreamingCompletionChunk {
    choices: Vec<StreamingChoice>,
    // Usage totals; in practice only present on the final chunk(s).
    usage: Option<openai::Usage>,
}

/// Final streaming response containing usage information.
#[derive(Clone, Deserialize, Serialize, Debug)]
pub struct StreamingCompletionResponse {
    /// Token usage from the streaming response.
    pub usage: openai::Usage,
}
404
405impl GetTokenUsage for StreamingCompletionResponse {
406    fn token_usage(&self) -> Option<crate::completion::Usage> {
407        let mut usage = crate::completion::Usage::new();
408        usage.input_tokens = self.usage.prompt_tokens as u64;
409        usage.total_tokens = self.usage.total_tokens as u64;
410        usage.output_tokens = self.usage.total_tokens as u64 - self.usage.prompt_tokens as u64;
411        Some(usage)
412    }
413}
414
/// Drive an SSE stream from llamafile, yielding text chunks and tool calls,
/// and finish with a [`StreamingCompletionResponse`] carrying token usage.
async fn send_streaming_request<T>(
    client: T,
    req: http::Request<Vec<u8>>,
    span: tracing::Span,
) -> Result<
    crate::streaming::StreamingCompletionResponse<StreamingCompletionResponse>,
    CompletionError,
>
where
    T: HttpClientExt + Clone + 'static,
{
    let mut event_source = GenericEventSource::new(client, req);

    let stream = stream! {
        let span = tracing::Span::current();
        let mut final_usage = openai::Usage {
            prompt_tokens: 0,
            total_tokens: 0,
            prompt_tokens_details: None,
        };
        // Accumulated full text; not otherwise used in this function.
        let mut text_response = String::new();
        // Partial tool calls keyed by stream index: (id, name, arguments-so-far).
        let mut calls: HashMap<usize, (String, String, String)> = HashMap::new();

        while let Some(event_result) = event_source.next().await {
            match event_result {
                Ok(Event::Open) => {
                    tracing::trace!("SSE connection opened");
                    continue;
                }
                Ok(Event::Message(message)) => {
                    let data_str = message.data.trim();
                    // "[DONE]" is the OpenAI-style end-of-stream sentinel.
                    if data_str.is_empty() || data_str == "[DONE]" {
                        continue;
                    }

                    // Unparseable payloads are logged and skipped rather than
                    // aborting the whole stream.
                    let parsed = serde_json::from_str::<StreamingCompletionChunk>(data_str);
                    let Ok(data) = parsed else {
                        let err = parsed.unwrap_err();
                        tracing::debug!("Couldn't parse SSE payload: {:?}", err);
                        continue;
                    };

                    if let Some(choice) = data.choices.first() {
                        let delta = &choice.delta;

                        // Handle tool calls
                        for tool_call in &delta.tool_calls {
                            let function = &tool_call.function;

                            // Start of tool call: name present, no arguments yet.
                            if function.name.as_ref().map(|s| !s.is_empty()).unwrap_or(false)
                                && empty_or_none(&function.arguments)
                            {
                                let id = tool_call.id.clone().unwrap_or_default();
                                let name = function.name.clone().unwrap();
                                calls.insert(tool_call.index, (id, name, String::new()));
                            }
                            // Continuation: append argument fragments to the call
                            // previously opened at this index (ignored if none).
                            else if function.name.as_ref().map(|s| s.is_empty()).unwrap_or(true)
                                && let Some(arguments) = &function.arguments
                                && !arguments.is_empty()
                            {
                                if let Some((id, name, existing_args)) = calls.get(&tool_call.index) {
                                    let combined = format!("{}{}", existing_args, arguments);
                                    calls.insert(tool_call.index, (id.clone(), name.clone(), combined));
                                }
                            }
                            // Complete tool call in a single chunk: emit immediately.
                            else {
                                let id = tool_call.id.clone().unwrap_or_default();
                                let name = function.name.clone().unwrap_or_default();
                                let arguments_str = function.arguments.clone().unwrap_or_default();

                                let Ok(arguments_json) = serde_json::from_str::<serde_json::Value>(&arguments_str) else {
                                    tracing::debug!("Couldn't parse tool call args '{}'", arguments_str);
                                    continue;
                                };

                                yield Ok(crate::streaming::RawStreamingChoice::ToolCall(
                                    crate::streaming::RawStreamingToolCall::new(id, name, arguments_json)
                                ));
                            }
                        }

                        // Streamed content
                        if let Some(content) = &delta.content {
                            text_response += content;
                            yield Ok(crate::streaming::RawStreamingChoice::Message(content.clone()));
                        }
                    }

                    // Usage typically arrives on the last chunk; keep the latest.
                    if let Some(usage) = data.usage {
                        final_usage = usage;
                    }
                }
                Err(crate::http_client::Error::StreamEnded) => break,
                Err(err) => {
                    tracing::error!(?err, "SSE error");
                    yield Err(CompletionError::ResponseError(err.to_string()));
                    break;
                }
            }
        }

        event_source.close();

        // Flush accumulated tool calls
        // (arguments that never became valid JSON are dropped silently).
        for (_, (id, name, arguments)) in calls {
            let Ok(arguments_json) = serde_json::from_str::<serde_json::Value>(&arguments) else {
                continue;
            };
            yield Ok(crate::streaming::RawStreamingChoice::ToolCall(
                crate::streaming::RawStreamingToolCall::new(id, name, arguments_json)
            ));
        }

        // NOTE(review): assumes total_tokens >= prompt_tokens; a server
        // reporting otherwise would make this subtraction underflow — confirm.
        span.record("gen_ai.usage.input_tokens", final_usage.prompt_tokens);
        span.record("gen_ai.usage.output_tokens", final_usage.total_tokens - final_usage.prompt_tokens);

        yield Ok(crate::streaming::RawStreamingChoice::FinalResponse(
            StreamingCompletionResponse { usage: final_usage }
        ));
    }.instrument(span);

    Ok(crate::streaming::StreamingCompletionResponse::stream(
        Box::pin(stream),
    ))
}
543
544// ================================================================
545// Embedding Model
546// ================================================================
547
/// Llamafile embedding model.
///
/// Llamafile supports the OpenAI-compatible `/v1/embeddings` endpoint.
#[derive(Clone)]
pub struct EmbeddingModel<T = reqwest::Client> {
    // Handle to the llamafile HTTP client.
    client: Client<T>,
    /// The model identifier.
    pub model: String,
    // Dimension count reported via `ndims()`.
    ndims: usize,
}

impl<T> EmbeddingModel<T> {
    /// Create a new embedding model for the given client, model name, and dimensions.
    pub fn new(client: Client<T>, model: impl Into<String>, ndims: usize) -> Self {
        Self {
            client,
            model: model.into(),
            ndims,
        }
    }
}
569
570impl<T> embeddings::EmbeddingModel for EmbeddingModel<T>
571where
572    T: HttpClientExt + Clone + std::fmt::Debug + Default + Send + 'static,
573{
574    const MAX_DOCUMENTS: usize = 1024;
575
576    type Client = Client<T>;
577
578    fn make(client: &Self::Client, model: impl Into<String>, ndims: Option<usize>) -> Self {
579        Self::new(client.clone(), model, ndims.unwrap_or_default())
580    }
581
582    fn ndims(&self) -> usize {
583        self.ndims
584    }
585
586    async fn embed_texts(
587        &self,
588        documents: impl IntoIterator<Item = String>,
589    ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
590        let documents = documents.into_iter().collect::<Vec<_>>();
591
592        let body = serde_json::json!({
593            "model": self.model,
594            "input": documents,
595        });
596
597        let body = serde_json::to_vec(&body)?;
598
599        let req = self
600            .client
601            .post("v1/embeddings")?
602            .body(body)
603            .map_err(|e| EmbeddingError::HttpError(e.into()))?;
604
605        let response = self.client.send(req).await?;
606
607        if response.status().is_success() {
608            let body: Vec<u8> = response.into_body().await?;
609            let body: ApiResponse<openai::EmbeddingResponse> = serde_json::from_slice(&body)?;
610
611            match body {
612                ApiResponse::Ok(response) => {
613                    tracing::info!(target: "rig",
614                        "Llamafile embedding token usage: {:?}",
615                        response.usage
616                    );
617
618                    if response.data.len() != documents.len() {
619                        return Err(EmbeddingError::ResponseError(
620                            "Response data length does not match input length".into(),
621                        ));
622                    }
623
624                    Ok(response
625                        .data
626                        .into_iter()
627                        .zip(documents.into_iter())
628                        .map(|(embedding, document)| embeddings::Embedding {
629                            document,
630                            vec: embedding
631                                .embedding
632                                .into_iter()
633                                .filter_map(|n| n.as_f64())
634                                .collect(),
635                        })
636                        .collect())
637                }
638                ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
639            }
640        } else {
641            let text = http_client::text(response).await?;
642            Err(EmbeddingError::ProviderError(text))
643        }
644    }
645}
646
647// ================================================================
648// Tests
649// ================================================================
#[cfg(test)]
mod tests {
    use super::*;
    use crate::client::Nothing;

    // Smoke test: both construction paths succeed without a running server.
    #[test]
    fn test_client_initialization() {
        let _client =
            crate::providers::llamafile::Client::new(Nothing).expect("Client::new() failed");
        let _client_from_builder = crate::providers::llamafile::Client::builder()
            .api_key(Nothing)
            .build()
            .expect("Client::builder() failed");
    }

    // Constructing from an explicit URL must not panic.
    #[test]
    fn test_client_from_url() {
        let _client = crate::providers::llamafile::Client::from_url("http://localhost:8080");
    }

    // Verifies the CompletionRequest -> wire-format conversion: the preamble
    // becomes a system message and scalar parameters pass through unchanged.
    #[test]
    fn test_completion_request_conversion() {
        use crate::OneOrMany;
        use crate::completion::Message as CompletionMessage;
        use crate::message::{Text, UserContent};

        let completion_request = CompletionRequest {
            model: None,
            preamble: Some("You are a helpful assistant.".to_string()),
            chat_history: OneOrMany::one(CompletionMessage::User {
                content: OneOrMany::one(UserContent::Text(Text {
                    text: "Hello!".to_string(),
                })),
            }),
            documents: vec![],
            tools: vec![],
            temperature: Some(0.7),
            max_tokens: Some(256),
            tool_choice: None,
            additional_params: None,
            output_schema: None,
        };

        let request = LlamafileCompletionRequest::try_from((LLAMA_CPP, completion_request))
            .expect("Failed to create request");

        assert_eq!(request.model, LLAMA_CPP);
        assert_eq!(request.messages.len(), 2); // system + user
        assert_eq!(request.temperature, Some(0.7));
        assert_eq!(request.max_tokens, Some(256));
    }
}