Skip to main content

rig_core/providers/
llamafile.rs

1//! Llamafile API client and Rig integration
2//!
3//! [Llamafile](https://github.com/Mozilla-Ocho/llamafile) is a Mozilla Builders project
4//! that distributes LLMs as single-file executables. When started, it exposes an
5//! OpenAI-compatible API at `http://localhost:8080/v1`.
6//!
7//! # Example
8//! ```rust,ignore
9//! use rig_core::providers::llamafile;
10//! use rig_core::completion::Prompt;
11//!
//! // Create a new Llamafile client pointing at a locally running llamafile server.
//! // `from_url` returns a `Result`, so propagate any construction error with `?`.
//! let client = llamafile::Client::from_url("http://localhost:8080")?;
14//!
15//! // Create an agent with a preamble
16//! let agent = client
17//!     .agent(llamafile::LLAMA_CPP)
18//!     .preamble("You are a helpful assistant.")
19//!     .build();
20//!
21//! // Prompt the agent and print the response
22//! let response = agent.prompt("Hello!").await?;
23//! println!("{response}");
24//! ```
25
26use crate::client::{
27    self, Capabilities, Capable, DebugExt, Nothing, Provider, ProviderBuilder, ProviderClient,
28};
29use crate::completion::GetTokenUsage;
30use crate::http_client::{self, HttpClientExt};
31use crate::providers::internal::openai_chat_completions_compatible::{
32    self, CompatibleChoiceData, CompatibleChunk, CompatibleFinishReason, CompatibleStreamProfile,
33};
34use crate::providers::openai::{self, StreamingToolCall};
35use crate::{
36    completion::{self, CompletionError, CompletionRequest},
37    embeddings::{self, EmbeddingError},
38    json_utils,
39};
40use bytes::Bytes;
41use serde::{Deserialize, Serialize};
42use serde_json::{Map, Value};
43use tracing::{Level, info_span};
44use tracing_futures::Instrument;
45
46// ================================================================
47// Main Llamafile Client
48// ================================================================
/// Default base URL of a llamafile server; used as the `BASE_URL` of
/// `LlamafileBuilder` when the caller does not supply another one.
const LLAMAFILE_API_BASE_URL: &str = "http://localhost:8080";

/// The default model identifier reported by llamafile.
pub const LLAMA_CPP: &str = "LLaMA_CPP";
53
/// Stateless provider extension type for llamafile.
///
/// It holds no data; the trait implementations on it declare the verification
/// path and which capabilities the provider offers.
#[derive(Debug, Default, Clone, Copy)]
pub struct LlamafileExt;
56
/// Stateless builder type for the llamafile provider.
///
/// Its `ProviderBuilder` implementation supplies the default base URL and
/// produces a `LlamafileExt`.
#[derive(Debug, Default, Clone, Copy)]
pub struct LlamafileBuilder;
59
// Ties the llamafile extension to its builder type.
impl Provider for LlamafileExt {
    type Builder = LlamafileBuilder;
    // Relative path of the model listing endpoint.
    // NOTE(review): presumably requested to verify connectivity — confirm
    // against the `Provider` trait definition.
    const VERIFY_PATH: &'static str = "v1/models";
}
64
// Capability table for llamafile: chat completions and embeddings are marked
// `Capable`; every other capability is marked `Nothing`.
impl<H> Capabilities<H> for LlamafileExt {
    type Completion = Capable<CompletionModel<H>>;
    type Embeddings = Capable<EmbeddingModel<H>>;
    type Transcription = Nothing;
    type ModelListing = Nothing;
    // Only present when the crate is built with the `image` feature.
    #[cfg(feature = "image")]
    type ImageGeneration = Nothing;
    // Only present when the crate is built with the `audio` feature.
    #[cfg(feature = "audio")]
    type AudioGeneration = Nothing;
    type Rerank = Nothing;
}
76
// Empty impl: no trait items are overridden here, so only whatever defaults
// `DebugExt` provides apply.
impl DebugExt for LlamafileExt {}
78
// Builder wiring for llamafile: no API key is required (`ApiKey = Nothing`)
// and the extension is the same stateless `LlamafileExt` for every HTTP client.
impl ProviderBuilder for LlamafileBuilder {
    type Extension<H>
        = LlamafileExt
    where
        H: HttpClientExt;
    type ApiKey = Nothing;

    // Default server address used when the caller does not set a base URL.
    const BASE_URL: &'static str = LLAMAFILE_API_BASE_URL;

    /// Produces the provider extension.
    ///
    /// The client builder is not inspected and this implementation always
    /// returns `Ok(LlamafileExt)`.
    fn build<H>(
        _builder: &client::ClientBuilder<Self, Self::ApiKey, H>,
    ) -> http_client::Result<Self::Extension<H>>
    where
        H: HttpClientExt,
    {
        Ok(LlamafileExt)
    }
}
97
/// Llamafile client, generic over the HTTP client implementation
/// (`reqwest::Client` by default).
pub type Client<H = reqwest::Client> = client::Client<LlamafileExt, H>;
/// Builder for a llamafile [`Client`]; no API key type is involved (`Nothing`).
pub type ClientBuilder<H = crate::markers::Missing> =
    client::ClientBuilder<LlamafileBuilder, Nothing, H>;
101
102impl Client {
103    /// Create a client pointing at the given llamafile base URL
104    /// (e.g. `http://localhost:8080`).
105    pub fn from_url(base_url: &str) -> crate::client::ProviderClientResult<Self> {
106        Self::builder()
107            .api_key(Nothing)
108            .base_url(base_url)
109            .build()
110            .map_err(Into::into)
111    }
112}
113
114impl ProviderClient for Client {
115    type Input = Nothing;
116    type Error = crate::client::ProviderClientError;
117
118    fn from_env() -> Result<Self, Self::Error> {
119        let api_base = crate::client::required_env_var("LLAMAFILE_API_BASE_URL")?;
120        Self::from_url(&api_base)
121    }
122
123    fn from_val(_: Self::Input) -> Result<Self, Self::Error> {
124        Self::builder().api_key(Nothing).build().map_err(Into::into)
125    }
126}
127
128// ================================================================
129// API Error Handling
130// ================================================================
131
/// Error payload returned by the server: a JSON object with a `message` field.
#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    /// Error description; surfaced to callers as a provider error.
    message: String,
}
136
/// Body of a response that may be either the expected payload or an error object.
///
/// The enum is untagged, so serde tries the variants in declaration order:
/// first the success payload `T`, then [`ApiErrorResponse`].
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ApiResponse<T> {
    /// The body deserialized as the expected payload.
    Ok(T),
    /// The body deserialized as an error object.
    Err(ApiErrorResponse),
}
143
144// ================================================================
145// Completion Request
146// ================================================================
147
/// Request body for llamafile's chat completions endpoint.
///
/// Llamafile uses the OpenAI chat completions format. Messages are converted
/// from the OpenAI `Message` type into plain JSON values before being stored
/// here, so the `messages` field holds `serde_json::Value`s.
#[derive(Debug, Serialize)]
struct LlamafileCompletionRequest {
    /// Model identifier sent as `model`.
    model: String,
    /// Conversation messages, already converted to JSON.
    messages: Vec<Value>,
    /// Sampling temperature; left out of the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f64>,
    /// Token limit; left out of the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    max_tokens: Option<u64>,
    /// Tool definitions; left out of the JSON when the list is empty.
    #[serde(skip_serializing_if = "Vec::is_empty")]
    tools: Vec<openai::ToolDefinition>,
    /// Extra caller-supplied parameters, flattened into the top-level JSON
    /// object; left out when `None`.
    #[serde(flatten, skip_serializing_if = "Option::is_none")]
    additional_params: Option<serde_json::Value>,
}
163
/// Concatenates the non-empty segments, separated by a blank line (`"\n\n"`).
///
/// Empty segments are skipped, so they never contribute a separator. The
/// result is an empty string when no non-empty segment is present.
fn join_text_segments<I>(segments: I) -> String
where
    I: IntoIterator<Item = String>,
{
    let mut joined = String::new();
    for segment in segments {
        if segment.is_empty() {
            continue;
        }
        // `joined` is non-empty exactly when a segment has already been
        // appended, because only non-empty segments are ever pushed.
        if !joined.is_empty() {
            joined.push_str("\n\n");
        }
        joined.push_str(&segment);
    }
    joined
}
179
180fn flatten_system_content(content: &crate::OneOrMany<openai::SystemContent>) -> String {
181    join_text_segments(content.iter().map(|item| item.text.clone()))
182}
183
184fn flatten_user_content(content: &crate::OneOrMany<openai::UserContent>) -> Option<String> {
185    content
186        .iter()
187        .map(|item| match item {
188            openai::UserContent::Text { text } => Some(text.clone()),
189            _ => None,
190        })
191        .collect::<Option<Vec<_>>>()
192        .map(join_text_segments)
193}
194
195fn flatten_assistant_content(content: &[openai::AssistantContent]) -> String {
196    join_text_segments(content.iter().map(|item| match item {
197        openai::AssistantContent::Text { text } => text.clone(),
198        openai::AssistantContent::Refusal { refusal } => refusal.clone(),
199    }))
200}
201
202fn optional_value<T>(value: Option<T>) -> Result<Option<Value>, CompletionError>
203where
204    T: Serialize,
205{
206    value
207        .map(serde_json::to_value)
208        .transpose()
209        .map_err(Into::into)
210}
211
212fn message_content_value<T>(
213    flattened: Option<String>,
214    original: &T,
215) -> Result<Value, CompletionError>
216where
217    T: Serialize,
218{
219    match flattened {
220        Some(text) => Ok(Value::String(text)),
221        None => Ok(serde_json::to_value(original)?),
222    }
223}
224
225fn llamafile_message_value(message: openai::Message) -> Result<Value, CompletionError> {
226    match message {
227        openai::Message::System { content, name } => {
228            let mut object = Map::new();
229            object.insert("role".into(), Value::String("system".into()));
230            object.insert(
231                "content".into(),
232                Value::String(flatten_system_content(&content)),
233            );
234            if let Some(name) = name {
235                object.insert("name".into(), Value::String(name));
236            }
237            Ok(Value::Object(object))
238        }
239        openai::Message::User { content, name } => {
240            let mut object = Map::new();
241            object.insert("role".into(), Value::String("user".into()));
242            object.insert(
243                "content".into(),
244                message_content_value(flatten_user_content(&content), &content)?,
245            );
246            if let Some(name) = name {
247                object.insert("name".into(), Value::String(name));
248            }
249            Ok(Value::Object(object))
250        }
251        openai::Message::Assistant {
252            content,
253            refusal,
254            reasoning: _,
255            audio,
256            name,
257            tool_calls,
258        } => {
259            let mut object = Map::new();
260            object.insert("role".into(), Value::String("assistant".into()));
261            object.insert(
262                "content".into(),
263                Value::String(flatten_assistant_content(&content)),
264            );
265            if let Some(refusal) = refusal {
266                object.insert("refusal".into(), Value::String(refusal));
267            }
268            if let Some(audio) = optional_value(audio)? {
269                object.insert("audio".into(), audio);
270            }
271            if let Some(name) = name {
272                object.insert("name".into(), Value::String(name));
273            }
274            if !tool_calls.is_empty() {
275                object.insert("tool_calls".into(), serde_json::to_value(tool_calls)?);
276            }
277            Ok(Value::Object(object))
278        }
279        openai::Message::ToolResult {
280            tool_call_id,
281            content,
282        } => {
283            let mut object = Map::new();
284            object.insert("role".into(), Value::String("tool".into()));
285            object.insert("tool_call_id".into(), Value::String(tool_call_id));
286            object.insert("content".into(), Value::String(content.as_text()));
287            Ok(Value::Object(object))
288        }
289    }
290}
291
292impl TryFrom<(&str, CompletionRequest)> for LlamafileCompletionRequest {
293    type Error = CompletionError;
294
295    fn try_from((model, req): (&str, CompletionRequest)) -> Result<Self, Self::Error> {
296        let chat_history = req.chat_history_with_documents();
297        if req.output_schema.is_some() {
298            tracing::warn!("Structured outputs may not be supported by llamafile");
299        }
300        let model = req.model.clone().unwrap_or_else(|| model.to_string());
301
302        // Build message history.
303        let mut full_history: Vec<openai::Message> = match &req.preamble {
304            Some(preamble) => vec![openai::Message::system(preamble)],
305            None => vec![],
306        };
307
308        let chat_history: Vec<openai::Message> = chat_history
309            .into_iter()
310            .map(|msg| msg.try_into())
311            .collect::<Result<Vec<Vec<openai::Message>>, _>>()?
312            .into_iter()
313            .flatten()
314            .collect();
315
316        full_history.extend(chat_history);
317
318        Ok(Self {
319            model,
320            messages: full_history
321                .into_iter()
322                .map(llamafile_message_value)
323                .collect::<Result<Vec<_>, _>>()?,
324            temperature: req.temperature,
325            max_tokens: req.max_tokens,
326            tools: req
327                .tools
328                .into_iter()
329                .map(openai::ToolDefinition::from)
330                .collect(),
331            additional_params: req.additional_params,
332        })
333    }
334}
335
336// ================================================================
337// Completion Model
338// ================================================================
339
/// Llamafile completion model.
///
/// Pairs a llamafile [`Client`] with the model identifier sent in each request.
#[derive(Clone)]
pub struct CompletionModel<T = reqwest::Client> {
    /// Client used to send requests to the llamafile server.
    client: Client<T>,
    /// The model identifier (usually `LLaMA_CPP`).
    pub model: String,
}
347
348impl<T> CompletionModel<T> {
349    /// Create a new completion model for the given client and model name.
350    pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
351        Self {
352            client,
353            model: model.into(),
354        }
355    }
356}
357
358impl<T> completion::CompletionModel for CompletionModel<T>
359where
360    T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
361{
362    type Response = openai::CompletionResponse;
363    type StreamingResponse = StreamingCompletionResponse;
364    type Client = Client<T>;
365
366    fn make(client: &Self::Client, model: impl Into<String>) -> Self {
367        Self::new(client.clone(), model)
368    }
369
370    async fn completion(
371        &self,
372        completion_request: CompletionRequest,
373    ) -> Result<completion::CompletionResponse<openai::CompletionResponse>, CompletionError> {
374        let span = if tracing::Span::current().is_disabled() {
375            info_span!(
376                target: "rig::completions",
377                "chat",
378                gen_ai.operation.name = "chat",
379                gen_ai.provider.name = "llamafile",
380                gen_ai.request.model = self.model,
381                gen_ai.system_instructions = completion_request.preamble,
382                gen_ai.response.id = tracing::field::Empty,
383                gen_ai.response.model = tracing::field::Empty,
384                gen_ai.usage.output_tokens = tracing::field::Empty,
385                gen_ai.usage.input_tokens = tracing::field::Empty,
386                gen_ai.usage.cache_read.input_tokens = tracing::field::Empty,
387            )
388        } else {
389            tracing::Span::current()
390        };
391
392        let request =
393            LlamafileCompletionRequest::try_from((self.model.as_ref(), completion_request))?;
394
395        if tracing::enabled!(Level::TRACE) {
396            tracing::trace!(target: "rig::completions",
397                "Llamafile completion request: {}",
398                serde_json::to_string_pretty(&request)?
399            );
400        }
401
402        let body = serde_json::to_vec(&request)?;
403        let req = self
404            .client
405            .post("v1/chat/completions")?
406            .body(body)
407            .map_err(|e| CompletionError::HttpError(e.into()))?;
408
409        async move {
410            let response = self.client.send::<_, Bytes>(req).await?;
411            let status = response.status();
412            let response_body = response.into_body().into_future().await?.to_vec();
413
414            if status.is_success() {
415                match serde_json::from_slice::<ApiResponse<openai::CompletionResponse>>(
416                    &response_body,
417                )? {
418                    ApiResponse::Ok(response) => {
419                        let span = tracing::Span::current();
420                        span.record("gen_ai.response.id", response.id.clone());
421                        span.record("gen_ai.response.model", response.model.clone());
422                        if let Some(ref usage) = response.usage {
423                            span.record("gen_ai.usage.input_tokens", usage.prompt_tokens);
424                            span.record(
425                                "gen_ai.usage.output_tokens",
426                                usage.total_tokens - usage.prompt_tokens,
427                            );
428                        }
429
430                        if tracing::enabled!(Level::TRACE) {
431                            tracing::trace!(target: "rig::completions",
432                                "Llamafile completion response: {}",
433                                serde_json::to_string_pretty(&response)?
434                            );
435                        }
436
437                        response.try_into()
438                    }
439                    ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
440                }
441            } else {
442                Err(CompletionError::ProviderError(
443                    String::from_utf8_lossy(&response_body).to_string(),
444                ))
445            }
446        }
447        .instrument(span)
448        .await
449    }
450
451    async fn stream(
452        &self,
453        completion_request: CompletionRequest,
454    ) -> Result<
455        crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
456        CompletionError,
457    > {
458        let span = if tracing::Span::current().is_disabled() {
459            info_span!(
460                target: "rig::completions",
461                "chat_streaming",
462                gen_ai.operation.name = "chat_streaming",
463                gen_ai.provider.name = "llamafile",
464                gen_ai.request.model = self.model,
465                gen_ai.system_instructions = completion_request.preamble,
466                gen_ai.response.id = tracing::field::Empty,
467                gen_ai.response.model = tracing::field::Empty,
468                gen_ai.usage.output_tokens = tracing::field::Empty,
469                gen_ai.usage.input_tokens = tracing::field::Empty,
470            )
471        } else {
472            tracing::Span::current()
473        };
474
475        let mut request =
476            LlamafileCompletionRequest::try_from((self.model.as_ref(), completion_request))?;
477
478        let params = json_utils::merge(
479            request.additional_params.unwrap_or(serde_json::json!({})),
480            serde_json::json!({"stream": true}),
481        );
482        request.additional_params = Some(params);
483
484        if tracing::enabled!(Level::TRACE) {
485            tracing::trace!(target: "rig::completions",
486                "Llamafile streaming completion request: {}",
487                serde_json::to_string_pretty(&request)?
488            );
489        }
490
491        let body = serde_json::to_vec(&request)?;
492        let req = self
493            .client
494            .post("v1/chat/completions")?
495            .body(body)
496            .map_err(|e| CompletionError::HttpError(e.into()))?;
497
498        send_streaming_request(self.client.clone(), req, span).await
499    }
500}
501
502// ================================================================
503// Streaming Support
504// ================================================================
505
/// Incremental payload of one streamed choice.
#[derive(Deserialize, Debug)]
struct StreamingDelta {
    /// Text fragment carried by this chunk, if any; a missing field becomes `None`.
    #[serde(default)]
    content: Option<String>,
    /// Tool call fragments carried by this chunk; a missing field becomes an
    /// empty list.
    // NOTE(review): `null_or_vec` presumably also maps an explicit JSON `null`
    // to an empty list — confirm in `json_utils`.
    #[serde(default, deserialize_with = "json_utils::null_or_vec")]
    tool_calls: Vec<StreamingToolCall>,
}
513
/// One entry of a streamed chunk's `choices` array.
#[derive(Deserialize, Debug)]
struct StreamingChoice {
    /// The incremental content for this choice.
    delta: StreamingDelta,
    /// Reason the choice finished, when reported; a missing field becomes `None`.
    #[serde(default)]
    finish_reason: Option<openai::completion::streaming::FinishReason>,
}
520
/// One parsed SSE payload of a streaming chat completion.
#[derive(Deserialize, Debug)]
struct StreamingCompletionChunk {
    /// Response identifier, when present.
    id: Option<String>,
    /// Model name reported by the server, when present.
    model: Option<String>,
    /// Streamed choices; the first one is the one passed on for normalization.
    choices: Vec<StreamingChoice>,
    /// Token usage, when the chunk carries it.
    usage: Option<openai::Usage>,
}
528
/// Final streaming response containing usage information.
///
/// Built by the stream profile's `build_final_response` from the usage value
/// it is handed.
#[derive(Clone, Deserialize, Serialize, Debug)]
pub struct StreamingCompletionResponse {
    /// Token usage from the streaming response.
    pub usage: openai::Usage,
}
535
impl GetTokenUsage for StreamingCompletionResponse {
    /// Returns the token usage by delegating to the wrapped OpenAI usage value.
    fn token_usage(&self) -> crate::completion::Usage {
        self.usage.token_usage()
    }
}
541
/// Stateless stream profile that adapts llamafile's streaming chunks to the
/// shared OpenAI-compatible streaming helper.
#[derive(Clone, Copy)]
struct LlamafileCompatibleProfile;
544
545impl CompatibleStreamProfile for LlamafileCompatibleProfile {
546    type Usage = openai::Usage;
547    type Detail = ();
548    type FinalResponse = StreamingCompletionResponse;
549
550    fn normalize_chunk(
551        &self,
552        data: &str,
553    ) -> Result<Option<CompatibleChunk<Self::Usage, Self::Detail>>, CompletionError> {
554        let data = match serde_json::from_str::<StreamingCompletionChunk>(data) {
555            Ok(data) => data,
556            Err(error) => {
557                tracing::debug!(
558                    ?error,
559                    "Couldn't parse SSE payload as StreamingCompletionChunk"
560                );
561                return Ok(None);
562            }
563        };
564
565        Ok(Some(
566            openai_chat_completions_compatible::normalize_first_choice_chunk(
567                data.id,
568                data.model,
569                data.usage,
570                &data.choices,
571                |choice| CompatibleChoiceData {
572                    finish_reason: if choice.finish_reason
573                        == Some(openai::completion::streaming::FinishReason::ToolCalls)
574                    {
575                        CompatibleFinishReason::ToolCalls
576                    } else {
577                        CompatibleFinishReason::Other
578                    },
579                    text: choice.delta.content.clone(),
580                    reasoning: None,
581                    tool_calls: openai_chat_completions_compatible::tool_call_chunks(
582                        &choice.delta.tool_calls,
583                    ),
584                    details: Vec::new(),
585                },
586            ),
587        ))
588    }
589
590    fn build_final_response(&self, usage: Self::Usage) -> Self::FinalResponse {
591        StreamingCompletionResponse { usage }
592    }
593
594    fn uses_distinct_tool_call_eviction(&self) -> bool {
595        true
596    }
597
598    fn emits_complete_single_chunk_tool_calls(&self) -> bool {
599        true
600    }
601}
602
/// Sends a prepared streaming request through the shared OpenAI-compatible
/// streaming helper, using the llamafile stream profile.
///
/// The helper's future is instrumented with `span` before being awaited, and
/// its result is returned unchanged.
async fn send_streaming_request<T>(
    client: T,
    req: http::Request<Vec<u8>>,
    span: tracing::Span,
) -> Result<
    crate::streaming::StreamingCompletionResponse<StreamingCompletionResponse>,
    CompletionError,
>
where
    T: HttpClientExt + Clone + 'static,
{
    // Fully-qualified call to `tracing`'s `Instrument`; the file imports
    // `tracing_futures::Instrument`, so this spelling names the trait explicitly.
    tracing::Instrument::instrument(
        openai_chat_completions_compatible::send_compatible_streaming_request(
            client,
            req,
            LlamafileCompatibleProfile,
        ),
        span,
    )
    .await
}
624
625// ================================================================
626// Embedding Model
627// ================================================================
628
/// Llamafile embedding model.
///
/// Llamafile supports the OpenAI-compatible `/v1/embeddings` endpoint.
#[derive(Clone)]
pub struct EmbeddingModel<T = reqwest::Client> {
    /// Client used to send requests to the llamafile server.
    client: Client<T>,
    /// The model identifier.
    pub model: String,
    /// Number of dimensions reported by `ndims()`; stored as given and not
    /// checked against the vectors the server returns.
    ndims: usize,
}
639
640impl<T> EmbeddingModel<T> {
641    /// Create a new embedding model for the given client, model name, and dimensions.
642    pub fn new(client: Client<T>, model: impl Into<String>, ndims: usize) -> Self {
643        Self {
644            client,
645            model: model.into(),
646            ndims,
647        }
648    }
649}
650
651impl<T> embeddings::EmbeddingModel for EmbeddingModel<T>
652where
653    T: HttpClientExt + Clone + std::fmt::Debug + Default + Send + 'static,
654{
655    const MAX_DOCUMENTS: usize = 1024;
656
657    type Client = Client<T>;
658
659    fn make(client: &Self::Client, model: impl Into<String>, ndims: Option<usize>) -> Self {
660        Self::new(client.clone(), model, ndims.unwrap_or_default())
661    }
662
663    fn ndims(&self) -> usize {
664        self.ndims
665    }
666
667    async fn embed_texts(
668        &self,
669        documents: impl IntoIterator<Item = String>,
670    ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
671        let documents = documents.into_iter().collect::<Vec<_>>();
672
673        let body = serde_json::json!({
674            "model": self.model,
675            "input": documents,
676        });
677
678        let body = serde_json::to_vec(&body)?;
679
680        let req = self
681            .client
682            .post("v1/embeddings")?
683            .body(body)
684            .map_err(|e| EmbeddingError::HttpError(e.into()))?;
685
686        let response = self.client.send(req).await?;
687
688        if response.status().is_success() {
689            let body: Vec<u8> = response.into_body().await?;
690            let body: ApiResponse<openai::EmbeddingResponse> = serde_json::from_slice(&body)?;
691
692            match body {
693                ApiResponse::Ok(response) => {
694                    tracing::info!(target: "rig",
695                        "Llamafile embedding token usage: {:?}",
696                        response.usage
697                    );
698
699                    if response.data.len() != documents.len() {
700                        return Err(EmbeddingError::ResponseError(
701                            "Response data length does not match input length".into(),
702                        ));
703                    }
704
705                    Ok(response
706                        .data
707                        .into_iter()
708                        .zip(documents.into_iter())
709                        .map(|(embedding, document)| embeddings::Embedding {
710                            document,
711                            vec: embedding
712                                .embedding
713                                .into_iter()
714                                .filter_map(|n| n.as_f64())
715                                .collect(),
716                        })
717                        .collect())
718                }
719                ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
720            }
721        } else {
722            let text = http_client::text(response).await?;
723            Err(EmbeddingError::ProviderError(text))
724        }
725    }
726}
727
728// ================================================================
729// Tests
730// ================================================================
/// Unit tests for client construction and request conversion.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::client::Nothing;
    use crate::completion::Document;
    use std::collections::HashMap;

    /// Both construction paths succeed without an API key.
    #[test]
    fn test_client_initialization() {
        let _client =
            crate::providers::llamafile::Client::new(Nothing).expect("Client::new() failed");
        let _client_from_builder = crate::providers::llamafile::Client::builder()
            .api_key(Nothing)
            .build()
            .expect("Client::builder() failed");
    }

    /// `from_url` succeeds for a well-formed base URL.
    ///
    /// The result is unwrapped on purpose: binding the `Result` without
    /// inspecting it would let the test pass even when construction fails.
    #[test]
    fn test_client_from_url() {
        let _client = crate::providers::llamafile::Client::from_url("http://localhost:8080")
            .expect("Client::from_url() failed");
    }

    /// The preamble becomes a system message and plain user text is sent as a
    /// string; scalar options are carried over unchanged.
    #[test]
    fn test_completion_request_conversion() {
        use crate::OneOrMany;
        use crate::completion::Message as CompletionMessage;
        use crate::message::{Text, UserContent};

        let completion_request = CompletionRequest {
            model: None,
            preamble: Some("You are a helpful assistant.".to_string()),
            chat_history: OneOrMany::one(CompletionMessage::User {
                content: OneOrMany::one(UserContent::Text(Text::new("Hello!".to_string()))),
            }),
            documents: vec![],
            tools: vec![],
            temperature: Some(0.7),
            max_tokens: Some(256),
            tool_choice: None,
            additional_params: None,
            output_schema: None,
        };

        let request = LlamafileCompletionRequest::try_from((LLAMA_CPP, completion_request))
            .expect("Failed to create request");

        assert_eq!(request.model, LLAMA_CPP);
        assert_eq!(request.messages.len(), 2); // system + user
        assert_eq!(
            request.messages[0]["content"],
            "You are a helpful assistant."
        );
        assert_eq!(request.messages[1]["content"], "Hello!");
        assert_eq!(request.temperature, Some(0.7));
        assert_eq!(request.max_tokens, Some(256));
    }

    /// Text-only documents are flattened into a single string message.
    #[test]
    fn test_completion_request_flattens_text_only_document_arrays() {
        use crate::completion::CompletionRequestBuilder;
        use crate::test_utils::MockCompletionModel;

        let completion_request = CompletionRequestBuilder::new(
            MockCompletionModel::default(),
            "What does glarb-glarb mean?",
        )
        .document(Document {
            id: "doc-1".into(),
            text: "Definition of flurbo: a green alien.".into(),
            additional_props: HashMap::new(),
        })
        .document(Document {
            id: "doc-2".into(),
            text: "Definition of glarb-glarb: an ancient farming tool.".into(),
            additional_props: HashMap::new(),
        })
        .build();

        let request = LlamafileCompletionRequest::try_from((LLAMA_CPP, completion_request))
            .expect("Failed to create request");

        assert_eq!(request.messages.len(), 2);
        assert!(request.messages[0]["content"].is_string());
        let documents = request.messages[0]["content"]
            .as_str()
            .expect("documents should serialize as a string");
        assert!(documents.contains("Definition of flurbo"));
        assert!(documents.contains("Definition of glarb-glarb"));
    }

    /// Assistant text is flattened to a string while tool calls stay an array.
    #[test]
    fn test_llamafile_message_value_flattens_assistant_text_content() {
        let message = openai::Message::Assistant {
            content: vec![openai::AssistantContent::Text {
                text: "Tool returned the answer.".into(),
            }],
            reasoning: None,
            refusal: None,
            audio: None,
            name: None,
            tool_calls: vec![openai::ToolCall {
                id: "call_1".into(),
                r#type: openai::ToolType::Function,
                function: openai::Function {
                    name: "weather".into(),
                    arguments: serde_json::json!({"city": "London"}),
                },
            }],
        };

        let value = llamafile_message_value(message).expect("message conversion should succeed");

        assert_eq!(value["role"], "assistant");
        assert_eq!(value["content"], "Tool returned the answer.");
        assert!(value["tool_calls"].is_array());
    }
}