
rig/providers/llamafile.rs

//! Llamafile API client and Rig integration
//!
//! [Llamafile](https://github.com/Mozilla-Ocho/llamafile) is a Mozilla Builders project
//! that distributes LLMs as single-file executables. When started, it exposes an
//! OpenAI-compatible API at `http://localhost:8080/v1`.
//!
//! # Example
//! ```rust,ignore
//! use rig::providers::llamafile;
//! use rig::completion::Prompt;
//!
//! // Create a new Llamafile client (llamafile's default address is http://localhost:8080)
//! let client = llamafile::Client::from_url("http://localhost:8080")?;
//!
//! // Create an agent with a preamble
//! let agent = client
//!     .agent(llamafile::LLAMA_CPP)
//!     .preamble("You are a helpful assistant.")
//!     .build();
//!
//! // Prompt the agent and print the response
//! let response = agent.prompt("Hello!").await?;
//! println!("{response}");
//! ```
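//!
//! Llamafile also serves the OpenAI-compatible `/v1/embeddings` route. A minimal sketch,
//! assuming the running llamafile bundles an embedding-capable model (the `768` dimension
//! value below is only a placeholder):
//! ```rust,ignore
//! use rig::embeddings::EmbeddingModel as _;
//! use rig::providers::llamafile;
//!
//! let client = llamafile::Client::from_url("http://localhost:8080")?;
//! let embedder = llamafile::EmbeddingModel::new(client, llamafile::LLAMA_CPP, 768);
//! let embeddings = embedder.embed_texts(vec!["Hello, world!".to_string()]).await?;
//! ```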

use crate::client::{
    self, Capabilities, Capable, DebugExt, Nothing, Provider, ProviderBuilder, ProviderClient,
};
use crate::completion::GetTokenUsage;
use crate::http_client::{self, HttpClientExt};
use crate::providers::internal::openai_chat_completions_compatible::{
    self, CompatibleChoiceData, CompatibleChunk, CompatibleFinishReason, CompatibleStreamProfile,
};
use crate::providers::openai::{self, StreamingToolCall};
use crate::{
    completion::{self, CompletionError, CompletionRequest},
    embeddings::{self, EmbeddingError},
    json_utils,
};
use bytes::Bytes;
use serde::{Deserialize, Serialize};
use serde_json::{Map, Value};
use tracing::{Level, info_span};
use tracing_futures::Instrument;

// ================================================================
// Main Llamafile Client
// ================================================================
const LLAMAFILE_API_BASE_URL: &str = "http://localhost:8080";

/// The default model identifier reported by llamafile.
pub const LLAMA_CPP: &str = "LLaMA_CPP";

#[derive(Debug, Default, Clone, Copy)]
pub struct LlamafileExt;

#[derive(Debug, Default, Clone, Copy)]
pub struct LlamafileBuilder;

impl Provider for LlamafileExt {
    type Builder = LlamafileBuilder;
    const VERIFY_PATH: &'static str = "v1/models";
}

impl<H> Capabilities<H> for LlamafileExt {
    type Completion = Capable<CompletionModel<H>>;
    type Embeddings = Capable<EmbeddingModel<H>>;
    type Transcription = Nothing;
    type ModelListing = Nothing;
    #[cfg(feature = "image")]
    type ImageGeneration = Nothing;
    #[cfg(feature = "audio")]
    type AudioGeneration = Nothing;
}

impl DebugExt for LlamafileExt {}

impl ProviderBuilder for LlamafileBuilder {
    type Extension<H>
        = LlamafileExt
    where
        H: HttpClientExt;
    type ApiKey = Nothing;

    const BASE_URL: &'static str = LLAMAFILE_API_BASE_URL;

    fn build<H>(
        _builder: &client::ClientBuilder<Self, Self::ApiKey, H>,
    ) -> http_client::Result<Self::Extension<H>>
    where
        H: HttpClientExt,
    {
        Ok(LlamafileExt)
    }
}

pub type Client<H = reqwest::Client> = client::Client<LlamafileExt, H>;
pub type ClientBuilder<H = reqwest::Client> = client::ClientBuilder<LlamafileBuilder, Nothing, H>;

impl Client {
    /// Create a client pointing at the given llamafile base URL
    /// (e.g. `http://localhost:8080`).
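    ///
    /// # Example
    /// ```rust,ignore
    /// // A minimal sketch: the URL just needs to reach a running llamafile server.
    /// let client = rig::providers::llamafile::Client::from_url("http://localhost:8080")?;
    /// ```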
    pub fn from_url(base_url: &str) -> crate::client::ProviderClientResult<Self> {
        Self::builder()
            .api_key(Nothing)
            .base_url(base_url)
            .build()
            .map_err(Into::into)
    }
}

impl ProviderClient for Client {
    type Input = Nothing;
    type Error = crate::client::ProviderClientError;

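    // llamafile needs no API key; the only environment-driven setting is the base URL,
    // read from LLAMAFILE_API_BASE_URL.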
    fn from_env() -> Result<Self, Self::Error> {
        let api_base = crate::client::required_env_var("LLAMAFILE_API_BASE_URL")?;
        Self::from_url(&api_base)
    }

    fn from_val(_: Self::Input) -> Result<Self, Self::Error> {
        Self::builder().api_key(Nothing).build().map_err(Into::into)
    }
}

// ================================================================
// API Error Handling
// ================================================================

#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    message: String,
}

#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ApiResponse<T> {
    Ok(T),
    Err(ApiErrorResponse),
}

// ================================================================
// Completion Request
// ================================================================

/// Llamafile uses the OpenAI chat completions format.
/// We reuse the OpenAI `Message` type for maximum compatibility.
#[derive(Debug, Serialize)]
struct LlamafileCompletionRequest {
    model: String,
    messages: Vec<Value>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    max_tokens: Option<u64>,
    #[serde(skip_serializing_if = "Vec::is_empty")]
    tools: Vec<openai::ToolDefinition>,
    #[serde(flatten, skip_serializing_if = "Option::is_none")]
    additional_params: Option<serde_json::Value>,
}

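/// Join the non-empty text segments into a single string, separating them with blank lines.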
fn join_text_segments<I>(segments: I) -> String
where
    I: IntoIterator<Item = String>,
{
    let segments = segments
        .into_iter()
        .filter(|segment| !segment.is_empty())
        .collect::<Vec<_>>();

    if segments.is_empty() {
        String::new()
    } else {
        segments.join("\n\n")
    }
}

fn flatten_system_content(content: &crate::OneOrMany<openai::SystemContent>) -> String {
    join_text_segments(content.iter().map(|item| item.text.clone()))
}

fn flatten_user_content(content: &crate::OneOrMany<openai::UserContent>) -> Option<String> {
    content
        .iter()
        .map(|item| match item {
            openai::UserContent::Text { text } => Some(text.clone()),
            _ => None,
        })
        .collect::<Option<Vec<_>>>()
        .map(join_text_segments)
}

fn flatten_assistant_content(content: &[openai::AssistantContent]) -> String {
    join_text_segments(content.iter().map(|item| match item {
        openai::AssistantContent::Text { text } => text.clone(),
        openai::AssistantContent::Refusal { refusal } => refusal.clone(),
    }))
}

fn optional_value<T>(value: Option<T>) -> Result<Option<Value>, CompletionError>
where
    T: Serialize,
{
    value
        .map(serde_json::to_value)
        .transpose()
        .map_err(Into::into)
}

fn message_content_value<T>(
    flattened: Option<String>,
    original: &T,
) -> Result<Value, CompletionError>
where
    T: Serialize,
{
    match flattened {
        Some(text) => Ok(Value::String(text)),
        None => Ok(serde_json::to_value(original)?),
    }
}

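/// Convert an OpenAI-style message into the flat JSON object accepted by llamafile's chat
/// endpoint, collapsing multi-part text content into a single string wherever possible.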
fn llamafile_message_value(message: openai::Message) -> Result<Value, CompletionError> {
    match message {
        openai::Message::System { content, name } => {
            let mut object = Map::new();
            object.insert("role".into(), Value::String("system".into()));
            object.insert(
                "content".into(),
                Value::String(flatten_system_content(&content)),
            );
            if let Some(name) = name {
                object.insert("name".into(), Value::String(name));
            }
            Ok(Value::Object(object))
        }
        openai::Message::User { content, name } => {
            let mut object = Map::new();
            object.insert("role".into(), Value::String("user".into()));
            object.insert(
                "content".into(),
                message_content_value(flatten_user_content(&content), &content)?,
            );
            if let Some(name) = name {
                object.insert("name".into(), Value::String(name));
            }
            Ok(Value::Object(object))
        }
        openai::Message::Assistant {
            content,
            refusal,
            reasoning: _,
            audio,
            name,
            tool_calls,
        } => {
            let mut object = Map::new();
            object.insert("role".into(), Value::String("assistant".into()));
            object.insert(
                "content".into(),
                Value::String(flatten_assistant_content(&content)),
            );
            if let Some(refusal) = refusal {
                object.insert("refusal".into(), Value::String(refusal));
            }
            if let Some(audio) = optional_value(audio)? {
                object.insert("audio".into(), audio);
            }
            if let Some(name) = name {
                object.insert("name".into(), Value::String(name));
            }
            if !tool_calls.is_empty() {
                object.insert("tool_calls".into(), serde_json::to_value(tool_calls)?);
            }
            Ok(Value::Object(object))
        }
        openai::Message::ToolResult {
            tool_call_id,
            content,
        } => {
            let mut object = Map::new();
            object.insert("role".into(), Value::String("tool".into()));
            object.insert("tool_call_id".into(), Value::String(tool_call_id));
            object.insert("content".into(), Value::String(content.as_text()));
            Ok(Value::Object(object))
        }
    }
}

impl TryFrom<(&str, CompletionRequest)> for LlamafileCompletionRequest {
    type Error = CompletionError;

    fn try_from((model, req): (&str, CompletionRequest)) -> Result<Self, Self::Error> {
        if req.output_schema.is_some() {
            tracing::warn!("Structured outputs may not be supported by llamafile");
        }
        let model = req.model.clone().unwrap_or_else(|| model.to_string());

        // Build message history: preamble -> documents -> chat history
        let mut full_history: Vec<openai::Message> = match &req.preamble {
            Some(preamble) => vec![openai::Message::system(preamble)],
            None => vec![],
        };

        if let Some(docs) = req.normalized_documents() {
            let docs: Vec<openai::Message> = docs.try_into()?;
            full_history.extend(docs);
        }

        let chat_history: Vec<openai::Message> = req
            .chat_history
            .clone()
            .into_iter()
            .map(|msg| msg.try_into())
            .collect::<Result<Vec<Vec<openai::Message>>, _>>()?
            .into_iter()
            .flatten()
            .collect();

        full_history.extend(chat_history);

        Ok(Self {
            model,
            messages: full_history
                .into_iter()
                .map(llamafile_message_value)
                .collect::<Result<Vec<_>, _>>()?,
            temperature: req.temperature,
            max_tokens: req.max_tokens,
            tools: req
                .tools
                .into_iter()
                .map(openai::ToolDefinition::from)
                .collect(),
            additional_params: req.additional_params,
        })
    }
}

// ================================================================
// Completion Model
// ================================================================

/// Llamafile completion model.
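///
/// Usually obtained via `client.agent(..)` as in the module-level example; it can also be
/// constructed directly. A minimal sketch:
///
/// ```rust,ignore
/// let model = llamafile::CompletionModel::new(client, llamafile::LLAMA_CPP);
/// ```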
#[derive(Clone)]
pub struct CompletionModel<T = reqwest::Client> {
    client: Client<T>,
    /// The model identifier (usually `LLaMA_CPP`).
    pub model: String,
}

impl<T> CompletionModel<T> {
    /// Create a new completion model for the given client and model name.
    pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
        Self {
            client,
            model: model.into(),
        }
    }
}

impl<T> completion::CompletionModel for CompletionModel<T>
where
    T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
{
    type Response = openai::CompletionResponse;
    type StreamingResponse = StreamingCompletionResponse;
    type Client = Client<T>;

    fn make(client: &Self::Client, model: impl Into<String>) -> Self {
        Self::new(client.clone(), model)
    }

    async fn completion(
        &self,
        completion_request: CompletionRequest,
    ) -> Result<completion::CompletionResponse<openai::CompletionResponse>, CompletionError> {
        let span = if tracing::Span::current().is_disabled() {
            info_span!(
                target: "rig::completions",
                "chat",
                gen_ai.operation.name = "chat",
                gen_ai.provider.name = "llamafile",
                gen_ai.request.model = self.model,
                gen_ai.system_instructions = completion_request.preamble,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
                gen_ai.usage.cache_read.input_tokens = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };

        let request =
            LlamafileCompletionRequest::try_from((self.model.as_ref(), completion_request))?;

        if tracing::enabled!(Level::TRACE) {
            tracing::trace!(target: "rig::completions",
                "Llamafile completion request: {}",
                serde_json::to_string_pretty(&request)?
            );
        }

        let body = serde_json::to_vec(&request)?;
        let req = self
            .client
            .post("v1/chat/completions")?
            .body(body)
            .map_err(|e| CompletionError::HttpError(e.into()))?;

        async move {
            let response = self.client.send::<_, Bytes>(req).await?;
            let status = response.status();
            let response_body = response.into_body().into_future().await?.to_vec();

            if status.is_success() {
                match serde_json::from_slice::<ApiResponse<openai::CompletionResponse>>(
                    &response_body,
                )? {
                    ApiResponse::Ok(response) => {
                        let span = tracing::Span::current();
                        span.record("gen_ai.response.id", response.id.clone());
                        span.record("gen_ai.response.model", response.model.clone());
                        if let Some(ref usage) = response.usage {
                            span.record("gen_ai.usage.input_tokens", usage.prompt_tokens);
                            span.record(
                                "gen_ai.usage.output_tokens",
                                usage.total_tokens - usage.prompt_tokens,
                            );
                        }

                        if tracing::enabled!(Level::TRACE) {
                            tracing::trace!(target: "rig::completions",
                                "Llamafile completion response: {}",
                                serde_json::to_string_pretty(&response)?
                            );
                        }

                        response.try_into()
                    }
                    ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
                }
            } else {
                Err(CompletionError::ProviderError(
                    String::from_utf8_lossy(&response_body).to_string(),
                ))
            }
        }
        .instrument(span)
        .await
    }

    async fn stream(
        &self,
        completion_request: CompletionRequest,
    ) -> Result<
        crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
        CompletionError,
    > {
        let span = if tracing::Span::current().is_disabled() {
            info_span!(
                target: "rig::completions",
                "chat_streaming",
                gen_ai.operation.name = "chat_streaming",
                gen_ai.provider.name = "llamafile",
                gen_ai.request.model = self.model,
                gen_ai.system_instructions = completion_request.preamble,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };

        let mut request =
            LlamafileCompletionRequest::try_from((self.model.as_ref(), completion_request))?;

        let params = json_utils::merge(
            request.additional_params.unwrap_or(serde_json::json!({})),
            serde_json::json!({"stream": true}),
        );
        request.additional_params = Some(params);

        if tracing::enabled!(Level::TRACE) {
            tracing::trace!(target: "rig::completions",
                "Llamafile streaming completion request: {}",
                serde_json::to_string_pretty(&request)?
            );
        }

        let body = serde_json::to_vec(&request)?;
        let req = self
            .client
            .post("v1/chat/completions")?
            .body(body)
            .map_err(|e| CompletionError::HttpError(e.into()))?;

        send_streaming_request(self.client.clone(), req, span).await
    }
}

// ================================================================
// Streaming Support
// ================================================================

#[derive(Deserialize, Debug)]
struct StreamingDelta {
    #[serde(default)]
    content: Option<String>,
    #[serde(default, deserialize_with = "json_utils::null_or_vec")]
    tool_calls: Vec<StreamingToolCall>,
}

#[derive(Deserialize, Debug)]
struct StreamingChoice {
    delta: StreamingDelta,
    #[serde(default)]
    finish_reason: Option<openai::completion::streaming::FinishReason>,
}

#[derive(Deserialize, Debug)]
struct StreamingCompletionChunk {
    id: Option<String>,
    model: Option<String>,
    choices: Vec<StreamingChoice>,
    usage: Option<openai::Usage>,
}

/// Final streaming response containing usage information.
#[derive(Clone, Deserialize, Serialize, Debug)]
pub struct StreamingCompletionResponse {
    /// Token usage from the streaming response.
    pub usage: openai::Usage,
}

impl GetTokenUsage for StreamingCompletionResponse {
    fn token_usage(&self) -> Option<crate::completion::Usage> {
        self.usage.token_usage()
    }
}

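/// Maps llamafile's OpenAI-style SSE chunks onto the shared OpenAI-compatible streaming
/// profile used by `send_compatible_streaming_request`.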
#[derive(Clone, Copy)]
struct LlamafileCompatibleProfile;

impl CompatibleStreamProfile for LlamafileCompatibleProfile {
    type Usage = openai::Usage;
    type Detail = ();
    type FinalResponse = StreamingCompletionResponse;

    fn normalize_chunk(
        &self,
        data: &str,
    ) -> Result<Option<CompatibleChunk<Self::Usage, Self::Detail>>, CompletionError> {
        let data = match serde_json::from_str::<StreamingCompletionChunk>(data) {
            Ok(data) => data,
            Err(error) => {
                tracing::debug!(
                    ?error,
                    "Couldn't parse SSE payload as StreamingCompletionChunk"
                );
                return Ok(None);
            }
        };

        Ok(Some(
            openai_chat_completions_compatible::normalize_first_choice_chunk(
                data.id,
                data.model,
                data.usage,
                &data.choices,
                |choice| CompatibleChoiceData {
                    finish_reason: if choice.finish_reason
                        == Some(openai::completion::streaming::FinishReason::ToolCalls)
                    {
                        CompatibleFinishReason::ToolCalls
                    } else {
                        CompatibleFinishReason::Other
                    },
                    text: choice.delta.content.clone(),
                    reasoning: None,
                    tool_calls: openai_chat_completions_compatible::tool_call_chunks(
                        &choice.delta.tool_calls,
                    ),
                    details: Vec::new(),
                },
            ),
        ))
    }

    fn build_final_response(&self, usage: Self::Usage) -> Self::FinalResponse {
        StreamingCompletionResponse { usage }
    }

    fn uses_distinct_tool_call_eviction(&self) -> bool {
        true
    }

    fn emits_complete_single_chunk_tool_calls(&self) -> bool {
        true
    }
}

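/// Send the streaming request through the shared OpenAI-compatible SSE handler, keeping the
/// provided tracing span attached to the whole stream.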
async fn send_streaming_request<T>(
    client: T,
    req: http::Request<Vec<u8>>,
    span: tracing::Span,
) -> Result<
    crate::streaming::StreamingCompletionResponse<StreamingCompletionResponse>,
    CompletionError,
>
where
    T: HttpClientExt + Clone + 'static,
{
    tracing::Instrument::instrument(
        openai_chat_completions_compatible::send_compatible_streaming_request(
            client,
            req,
            LlamafileCompatibleProfile,
        ),
        span,
    )
    .await
}

// ================================================================
// Embedding Model
// ================================================================

/// Llamafile embedding model.
///
/// Llamafile supports the OpenAI-compatible `/v1/embeddings` endpoint.
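///
/// No dimension count is read back from the server; `ndims` is whatever the caller supplies,
/// and the generic `make` constructor falls back to `0` when no dimension hint is given.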
#[derive(Clone)]
pub struct EmbeddingModel<T = reqwest::Client> {
    client: Client<T>,
    /// The model identifier.
    pub model: String,
    ndims: usize,
}

impl<T> EmbeddingModel<T> {
    /// Create a new embedding model for the given client, model name, and dimensions.
    pub fn new(client: Client<T>, model: impl Into<String>, ndims: usize) -> Self {
        Self {
            client,
            model: model.into(),
            ndims,
        }
    }
}

impl<T> embeddings::EmbeddingModel for EmbeddingModel<T>
where
    T: HttpClientExt + Clone + std::fmt::Debug + Default + Send + 'static,
{
    const MAX_DOCUMENTS: usize = 1024;

    type Client = Client<T>;

    fn make(client: &Self::Client, model: impl Into<String>, ndims: Option<usize>) -> Self {
        Self::new(client.clone(), model, ndims.unwrap_or_default())
    }

    fn ndims(&self) -> usize {
        self.ndims
    }

    async fn embed_texts(
        &self,
        documents: impl IntoIterator<Item = String>,
    ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
        let documents = documents.into_iter().collect::<Vec<_>>();

        let body = serde_json::json!({
            "model": self.model,
            "input": documents,
        });

        let body = serde_json::to_vec(&body)?;

        let req = self
            .client
            .post("v1/embeddings")?
            .body(body)
            .map_err(|e| EmbeddingError::HttpError(e.into()))?;

        let response = self.client.send(req).await?;

        if response.status().is_success() {
            let body: Vec<u8> = response.into_body().await?;
            let body: ApiResponse<openai::EmbeddingResponse> = serde_json::from_slice(&body)?;

            match body {
                ApiResponse::Ok(response) => {
                    tracing::info!(target: "rig",
                        "Llamafile embedding token usage: {:?}",
                        response.usage
                    );

                    if response.data.len() != documents.len() {
                        return Err(EmbeddingError::ResponseError(
                            "Response data length does not match input length".into(),
                        ));
                    }

                    Ok(response
                        .data
                        .into_iter()
                        .zip(documents.into_iter())
                        .map(|(embedding, document)| embeddings::Embedding {
                            document,
                            vec: embedding
                                .embedding
                                .into_iter()
                                .filter_map(|n| n.as_f64())
                                .collect(),
                        })
                        .collect())
                }
                ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
            }
        } else {
            let text = http_client::text(response).await?;
            Err(EmbeddingError::ProviderError(text))
        }
    }
}

// ================================================================
// Tests
// ================================================================
#[cfg(test)]
mod tests {
    use super::*;
    use crate::client::Nothing;
    use crate::completion::Document;
    use std::collections::HashMap;

    #[test]
    fn test_client_initialization() {
        let _client =
            crate::providers::llamafile::Client::new(Nothing).expect("Client::new() failed");
        let _client_from_builder = crate::providers::llamafile::Client::builder()
            .api_key(Nothing)
            .build()
            .expect("Client::builder() failed");
    }

    #[test]
    fn test_client_from_url() {
        let _client = crate::providers::llamafile::Client::from_url("http://localhost:8080")
            .expect("Client::from_url() failed");
    }

    #[test]
    fn test_completion_request_conversion() {
        use crate::OneOrMany;
        use crate::completion::Message as CompletionMessage;
        use crate::message::{Text, UserContent};

        let completion_request = CompletionRequest {
            model: None,
            preamble: Some("You are a helpful assistant.".to_string()),
            chat_history: OneOrMany::one(CompletionMessage::User {
                content: OneOrMany::one(UserContent::Text(Text {
                    text: "Hello!".to_string(),
                })),
            }),
            documents: vec![],
            tools: vec![],
            temperature: Some(0.7),
            max_tokens: Some(256),
            tool_choice: None,
            additional_params: None,
            output_schema: None,
        };

        let request = LlamafileCompletionRequest::try_from((LLAMA_CPP, completion_request))
            .expect("Failed to create request");

        assert_eq!(request.model, LLAMA_CPP);
        assert_eq!(request.messages.len(), 2); // system + user
        assert_eq!(
            request.messages[0]["content"],
            "You are a helpful assistant."
        );
        assert_eq!(request.messages[1]["content"], "Hello!");
        assert_eq!(request.temperature, Some(0.7));
        assert_eq!(request.max_tokens, Some(256));
    }

    #[test]
    fn test_completion_request_flattens_text_only_document_arrays() {
        use crate::OneOrMany;
        use crate::completion::Message as CompletionMessage;
        use crate::message::{Text, UserContent};

        let completion_request = CompletionRequest {
            model: None,
            preamble: None,
            chat_history: OneOrMany::one(CompletionMessage::User {
                content: OneOrMany::one(UserContent::Text(Text {
                    text: "What does glarb-glarb mean?".to_string(),
                })),
            }),
            documents: vec![
                Document {
                    id: "doc-1".into(),
                    text: "Definition of flurbo: a green alien.".into(),
                    additional_props: HashMap::new(),
                },
                Document {
                    id: "doc-2".into(),
                    text: "Definition of glarb-glarb: an ancient farming tool.".into(),
                    additional_props: HashMap::new(),
                },
            ],
            tools: vec![],
            temperature: None,
            max_tokens: None,
            tool_choice: None,
            additional_params: None,
            output_schema: None,
        };

        let request = LlamafileCompletionRequest::try_from((LLAMA_CPP, completion_request))
            .expect("Failed to create request");

        assert_eq!(request.messages.len(), 2);
        assert!(request.messages[0]["content"].is_string());
        let documents = request.messages[0]["content"]
            .as_str()
            .expect("documents should serialize as a string");
        assert!(documents.contains("Definition of flurbo"));
        assert!(documents.contains("Definition of glarb-glarb"));
    }

    #[test]
    fn test_llamafile_message_value_flattens_assistant_text_content() {
        let message = openai::Message::Assistant {
            content: vec![openai::AssistantContent::Text {
                text: "Tool returned the answer.".into(),
            }],
            reasoning: None,
            refusal: None,
            audio: None,
            name: None,
            tool_calls: vec![openai::ToolCall {
                id: "call_1".into(),
                r#type: openai::ToolType::Function,
                function: openai::Function {
                    name: "weather".into(),
                    arguments: serde_json::json!({"city": "London"}),
                },
            }],
        };

        let value = llamafile_message_value(message).expect("message conversion should succeed");

        assert_eq!(value["role"], "assistant");
        assert_eq!(value["content"], "Tool returned the answer.");
        assert!(value["tool_calls"].is_array());
    }
866}