// rig_core/providers/llamafile.rs

//! Llamafile API client and Rig integration
//!
//! [Llamafile](https://github.com/Mozilla-Ocho/llamafile) is a Mozilla Builders project
//! that distributes LLMs as single-file executables. When started, it exposes an
//! OpenAI-compatible API at `http://localhost:8080/v1`.
//!
//! # Example
//! ```rust,ignore
//! use rig_core::providers::llamafile;
//! use rig_core::completion::Prompt;
//!
//! // Create a new Llamafile client (defaults to http://localhost:8080)
//! let client = llamafile::Client::from_url("http://localhost:8080")?;
//!
//! // Create an agent with a preamble
//! let agent = client
//!     .agent(llamafile::LLAMA_CPP)
//!     .preamble("You are a helpful assistant.")
//!     .build();
//!
//! // Prompt the agent and print the response
//! let response = agent.prompt("Hello!").await?;
//! println!("{response}");
//! ```
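//!
//! Alternatively, `Client::from_env()` reads the base URL from the
//! `LLAMAFILE_API_BASE_URL` environment variable.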

use crate::client::{
    self, Capabilities, Capable, DebugExt, Nothing, Provider, ProviderBuilder, ProviderClient,
};
use crate::completion::GetTokenUsage;
use crate::http_client::{self, HttpClientExt};
use crate::providers::internal::openai_chat_completions_compatible::{
    self, CompatibleChoiceData, CompatibleChunk, CompatibleFinishReason, CompatibleStreamProfile,
};
use crate::providers::openai::{self, StreamingToolCall};
use crate::{
    completion::{self, CompletionError, CompletionRequest},
    embeddings::{self, EmbeddingError},
    json_utils,
};
use bytes::Bytes;
use serde::{Deserialize, Serialize};
use serde_json::{Map, Value};
use tracing::{Level, info_span};
use tracing_futures::Instrument;

// ================================================================
// Main Llamafile Client
// ================================================================
const LLAMAFILE_API_BASE_URL: &str = "http://localhost:8080";

/// The default model identifier reported by llamafile.
pub const LLAMA_CPP: &str = "LLaMA_CPP";

#[derive(Debug, Default, Clone, Copy)]
pub struct LlamafileExt;

#[derive(Debug, Default, Clone, Copy)]
pub struct LlamafileBuilder;

impl Provider for LlamafileExt {
    type Builder = LlamafileBuilder;
    const VERIFY_PATH: &'static str = "v1/models";
}

impl<H> Capabilities<H> for LlamafileExt {
    type Completion = Capable<CompletionModel<H>>;
    type Embeddings = Capable<EmbeddingModel<H>>;
    type Transcription = Nothing;
    type ModelListing = Nothing;
    #[cfg(feature = "image")]
    type ImageGeneration = Nothing;
    #[cfg(feature = "audio")]
    type AudioGeneration = Nothing;
}

impl DebugExt for LlamafileExt {}

impl ProviderBuilder for LlamafileBuilder {
    type Extension<H>
        = LlamafileExt
    where
        H: HttpClientExt;
    type ApiKey = Nothing;

    const BASE_URL: &'static str = LLAMAFILE_API_BASE_URL;

    fn build<H>(
        _builder: &client::ClientBuilder<Self, Self::ApiKey, H>,
    ) -> http_client::Result<Self::Extension<H>>
    where
        H: HttpClientExt,
    {
        Ok(LlamafileExt)
    }
}

pub type Client<H = reqwest::Client> = client::Client<LlamafileExt, H>;
pub type ClientBuilder<H = crate::markers::Missing> =
    client::ClientBuilder<LlamafileBuilder, Nothing, H>;

impl Client {
    /// Create a client pointing at the given llamafile base URL
    /// (e.g. `http://localhost:8080`).
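    ///
    /// # Example
    /// ```rust,ignore
    /// // Minimal sketch: assumes a llamafile server is listening locally.
    /// let client = llamafile::Client::from_url("http://localhost:8080")?;
    /// ```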
    pub fn from_url(base_url: &str) -> crate::client::ProviderClientResult<Self> {
        Self::builder()
            .api_key(Nothing)
            .base_url(base_url)
            .build()
            .map_err(Into::into)
    }
}

impl ProviderClient for Client {
    type Input = Nothing;
    type Error = crate::client::ProviderClientError;

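    /// Build a client from the `LLAMAFILE_API_BASE_URL` environment variable.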
    fn from_env() -> Result<Self, Self::Error> {
        let api_base = crate::client::required_env_var("LLAMAFILE_API_BASE_URL")?;
        Self::from_url(&api_base)
    }

    fn from_val(_: Self::Input) -> Result<Self, Self::Error> {
        Self::builder().api_key(Nothing).build().map_err(Into::into)
    }
}

// ================================================================
// API Error Handling
// ================================================================

#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    message: String,
}

#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ApiResponse<T> {
    Ok(T),
    Err(ApiErrorResponse),
}

// ================================================================
// Completion Request
// ================================================================

/// Request body for llamafile's OpenAI-compatible chat completions endpoint.
///
/// Messages are built from the OpenAI `Message` type and flattened into plain
/// JSON objects (string `content`) for maximum compatibility with llama.cpp's
/// server.
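///
/// A request typically serializes to something like the following (an
/// illustrative sketch; the values are hypothetical):
///
/// ```json
/// {
///   "model": "LLaMA_CPP",
///   "messages": [
///     {"role": "system", "content": "You are a helpful assistant."},
///     {"role": "user", "content": "Hello!"}
///   ],
///   "temperature": 0.7,
///   "max_tokens": 256
/// }
/// ```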
#[derive(Debug, Serialize)]
struct LlamafileCompletionRequest {
    model: String,
    messages: Vec<Value>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    max_tokens: Option<u64>,
    #[serde(skip_serializing_if = "Vec::is_empty")]
    tools: Vec<openai::ToolDefinition>,
    #[serde(flatten, skip_serializing_if = "Option::is_none")]
    additional_params: Option<serde_json::Value>,
}

fn join_text_segments<I>(segments: I) -> String
where
    I: IntoIterator<Item = String>,
{
    let segments = segments
        .into_iter()
        .filter(|segment| !segment.is_empty())
        .collect::<Vec<_>>();

    if segments.is_empty() {
        String::new()
    } else {
        segments.join("\n\n")
    }
}

fn flatten_system_content(content: &crate::OneOrMany<openai::SystemContent>) -> String {
    join_text_segments(content.iter().map(|item| item.text.clone()))
}

fn flatten_user_content(content: &crate::OneOrMany<openai::UserContent>) -> Option<String> {
    content
        .iter()
        .map(|item| match item {
            openai::UserContent::Text { text } => Some(text.clone()),
            _ => None,
        })
        .collect::<Option<Vec<_>>>()
        .map(join_text_segments)
}

fn flatten_assistant_content(content: &[openai::AssistantContent]) -> String {
    join_text_segments(content.iter().map(|item| match item {
        openai::AssistantContent::Text { text } => text.clone(),
        openai::AssistantContent::Refusal { refusal } => refusal.clone(),
    }))
}

fn optional_value<T>(value: Option<T>) -> Result<Option<Value>, CompletionError>
where
    T: Serialize,
{
    value
        .map(serde_json::to_value)
        .transpose()
        .map_err(Into::into)
}

fn message_content_value<T>(
    flattened: Option<String>,
    original: &T,
) -> Result<Value, CompletionError>
where
    T: Serialize,
{
    match flattened {
        Some(text) => Ok(Value::String(text)),
        None => Ok(serde_json::to_value(original)?),
    }
}

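/// Convert an OpenAI-style message into the flat JSON object shape this module
/// sends to llamafile: a string `content`, plus optional `name`, `refusal`,
/// `audio`, and `tool_calls` fields depending on the role.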
fn llamafile_message_value(message: openai::Message) -> Result<Value, CompletionError> {
    match message {
        openai::Message::System { content, name } => {
            let mut object = Map::new();
            object.insert("role".into(), Value::String("system".into()));
            object.insert(
                "content".into(),
                Value::String(flatten_system_content(&content)),
            );
            if let Some(name) = name {
                object.insert("name".into(), Value::String(name));
            }
            Ok(Value::Object(object))
        }
        openai::Message::User { content, name } => {
            let mut object = Map::new();
            object.insert("role".into(), Value::String("user".into()));
            object.insert(
                "content".into(),
                message_content_value(flatten_user_content(&content), &content)?,
            );
            if let Some(name) = name {
                object.insert("name".into(), Value::String(name));
            }
            Ok(Value::Object(object))
        }
        openai::Message::Assistant {
            content,
            refusal,
            reasoning: _,
            audio,
            name,
            tool_calls,
        } => {
            let mut object = Map::new();
            object.insert("role".into(), Value::String("assistant".into()));
            object.insert(
                "content".into(),
                Value::String(flatten_assistant_content(&content)),
            );
            if let Some(refusal) = refusal {
                object.insert("refusal".into(), Value::String(refusal));
            }
            if let Some(audio) = optional_value(audio)? {
                object.insert("audio".into(), audio);
            }
            if let Some(name) = name {
                object.insert("name".into(), Value::String(name));
            }
            if !tool_calls.is_empty() {
                object.insert("tool_calls".into(), serde_json::to_value(tool_calls)?);
            }
            Ok(Value::Object(object))
        }
        openai::Message::ToolResult {
            tool_call_id,
            content,
        } => {
            let mut object = Map::new();
            object.insert("role".into(), Value::String("tool".into()));
            object.insert("tool_call_id".into(), Value::String(tool_call_id));
            object.insert("content".into(), Value::String(content.as_text()));
            Ok(Value::Object(object))
        }
    }
}

impl TryFrom<(&str, CompletionRequest)> for LlamafileCompletionRequest {
    type Error = CompletionError;

    fn try_from((model, req): (&str, CompletionRequest)) -> Result<Self, Self::Error> {
        if req.output_schema.is_some() {
            tracing::warn!("Structured outputs may not be supported by llamafile");
        }
        let model = req.model.clone().unwrap_or_else(|| model.to_string());

        // Build message history: preamble -> documents -> chat history
        let mut full_history: Vec<openai::Message> = match &req.preamble {
            Some(preamble) => vec![openai::Message::system(preamble)],
            None => vec![],
        };

        if let Some(docs) = req.normalized_documents() {
            let docs: Vec<openai::Message> = docs.try_into()?;
            full_history.extend(docs);
        }

        let chat_history: Vec<openai::Message> = req
            .chat_history
            .clone()
            .into_iter()
            .map(|msg| msg.try_into())
            .collect::<Result<Vec<Vec<openai::Message>>, _>>()?
            .into_iter()
            .flatten()
            .collect();

        full_history.extend(chat_history);

        Ok(Self {
            model,
            messages: full_history
                .into_iter()
                .map(llamafile_message_value)
                .collect::<Result<Vec<_>, _>>()?,
            temperature: req.temperature,
            max_tokens: req.max_tokens,
            tools: req
                .tools
                .into_iter()
                .map(openai::ToolDefinition::from)
                .collect(),
            additional_params: req.additional_params,
        })
    }
}

// ================================================================
// Completion Model
// ================================================================

/// Llamafile completion model.
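///
/// A minimal usage sketch (assumes a llamafile server is running locally):
/// ```rust,ignore
/// let client = llamafile::Client::from_url("http://localhost:8080")?;
/// let model = llamafile::CompletionModel::new(client, llamafile::LLAMA_CPP);
/// ```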
#[derive(Clone)]
pub struct CompletionModel<T = reqwest::Client> {
    client: Client<T>,
    /// The model identifier (usually `LLaMA_CPP`).
    pub model: String,
}

impl<T> CompletionModel<T> {
    /// Create a new completion model for the given client and model name.
    pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
        Self {
            client,
            model: model.into(),
        }
    }
}

impl<T> completion::CompletionModel for CompletionModel<T>
where
    T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
{
    type Response = openai::CompletionResponse;
    type StreamingResponse = StreamingCompletionResponse;
    type Client = Client<T>;

    fn make(client: &Self::Client, model: impl Into<String>) -> Self {
        Self::new(client.clone(), model)
    }

    async fn completion(
        &self,
        completion_request: CompletionRequest,
    ) -> Result<completion::CompletionResponse<openai::CompletionResponse>, CompletionError> {
        let span = if tracing::Span::current().is_disabled() {
            info_span!(
                target: "rig::completions",
                "chat",
                gen_ai.operation.name = "chat",
                gen_ai.provider.name = "llamafile",
                gen_ai.request.model = self.model,
                gen_ai.system_instructions = completion_request.preamble,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
                gen_ai.usage.cache_read.input_tokens = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };

        let request =
            LlamafileCompletionRequest::try_from((self.model.as_ref(), completion_request))?;

        if tracing::enabled!(Level::TRACE) {
            tracing::trace!(target: "rig::completions",
                "Llamafile completion request: {}",
                serde_json::to_string_pretty(&request)?
            );
        }

        let body = serde_json::to_vec(&request)?;
        let req = self
            .client
            .post("v1/chat/completions")?
            .body(body)
            .map_err(|e| CompletionError::HttpError(e.into()))?;

        async move {
            let response = self.client.send::<_, Bytes>(req).await?;
            let status = response.status();
            let response_body = response.into_body().into_future().await?.to_vec();

            if status.is_success() {
                match serde_json::from_slice::<ApiResponse<openai::CompletionResponse>>(
                    &response_body,
                )? {
                    ApiResponse::Ok(response) => {
                        let span = tracing::Span::current();
                        span.record("gen_ai.response.id", response.id.clone());
                        span.record("gen_ai.response.model", response.model.clone());
                        if let Some(ref usage) = response.usage {
                            span.record("gen_ai.usage.input_tokens", usage.prompt_tokens);
                            span.record(
                                "gen_ai.usage.output_tokens",
                                usage.total_tokens - usage.prompt_tokens,
                            );
                        }

                        if tracing::enabled!(Level::TRACE) {
                            tracing::trace!(target: "rig::completions",
                                "Llamafile completion response: {}",
                                serde_json::to_string_pretty(&response)?
                            );
                        }

                        response.try_into()
                    }
                    ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
                }
            } else {
                Err(CompletionError::ProviderError(
                    String::from_utf8_lossy(&response_body).to_string(),
                ))
            }
        }
        .instrument(span)
        .await
    }

    async fn stream(
        &self,
        completion_request: CompletionRequest,
    ) -> Result<
        crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
        CompletionError,
    > {
        let span = if tracing::Span::current().is_disabled() {
            info_span!(
                target: "rig::completions",
                "chat_streaming",
                gen_ai.operation.name = "chat_streaming",
                gen_ai.provider.name = "llamafile",
                gen_ai.request.model = self.model,
                gen_ai.system_instructions = completion_request.preamble,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };

        let mut request =
            LlamafileCompletionRequest::try_from((self.model.as_ref(), completion_request))?;

        let params = json_utils::merge(
            request.additional_params.unwrap_or(serde_json::json!({})),
            serde_json::json!({"stream": true}),
        );
        request.additional_params = Some(params);

        if tracing::enabled!(Level::TRACE) {
            tracing::trace!(target: "rig::completions",
                "Llamafile streaming completion request: {}",
                serde_json::to_string_pretty(&request)?
            );
        }

        let body = serde_json::to_vec(&request)?;
        let req = self
            .client
            .post("v1/chat/completions")?
            .body(body)
            .map_err(|e| CompletionError::HttpError(e.into()))?;

        send_streaming_request(self.client.clone(), req, span).await
    }
}

// ================================================================
// Streaming Support
// ================================================================

#[derive(Deserialize, Debug)]
struct StreamingDelta {
    #[serde(default)]
    content: Option<String>,
    #[serde(default, deserialize_with = "json_utils::null_or_vec")]
    tool_calls: Vec<StreamingToolCall>,
}

#[derive(Deserialize, Debug)]
struct StreamingChoice {
    delta: StreamingDelta,
    #[serde(default)]
    finish_reason: Option<openai::completion::streaming::FinishReason>,
}

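/// A single SSE chunk in the OpenAI `chat.completion.chunk` shape emitted by
/// llamafile's streaming endpoint.
///
/// Illustrative payload (a sketch, not captured from a real stream):
/// `{"id":"chatcmpl-1","model":"LLaMA_CPP","choices":[{"delta":{"content":"Hi"},"finish_reason":null}],"usage":null}`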
#[derive(Deserialize, Debug)]
struct StreamingCompletionChunk {
    id: Option<String>,
    model: Option<String>,
    choices: Vec<StreamingChoice>,
    usage: Option<openai::Usage>,
}

/// Final streaming response containing usage information.
#[derive(Clone, Deserialize, Serialize, Debug)]
pub struct StreamingCompletionResponse {
    /// Token usage from the streaming response.
    pub usage: openai::Usage,
}

impl GetTokenUsage for StreamingCompletionResponse {
    fn token_usage(&self) -> Option<crate::completion::Usage> {
        self.usage.token_usage()
    }
}

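/// Stream profile that adapts llamafile's OpenAI-compatible SSE chunks to the
/// shared `openai_chat_completions_compatible` streaming machinery.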
#[derive(Clone, Copy)]
struct LlamafileCompatibleProfile;

impl CompatibleStreamProfile for LlamafileCompatibleProfile {
    type Usage = openai::Usage;
    type Detail = ();
    type FinalResponse = StreamingCompletionResponse;

    fn normalize_chunk(
        &self,
        data: &str,
    ) -> Result<Option<CompatibleChunk<Self::Usage, Self::Detail>>, CompletionError> {
        let data = match serde_json::from_str::<StreamingCompletionChunk>(data) {
            Ok(data) => data,
            Err(error) => {
                tracing::debug!(
                    ?error,
                    "Couldn't parse SSE payload as StreamingCompletionChunk"
                );
                return Ok(None);
            }
        };

        Ok(Some(
            openai_chat_completions_compatible::normalize_first_choice_chunk(
                data.id,
                data.model,
                data.usage,
                &data.choices,
                |choice| CompatibleChoiceData {
                    finish_reason: if choice.finish_reason
                        == Some(openai::completion::streaming::FinishReason::ToolCalls)
                    {
                        CompatibleFinishReason::ToolCalls
                    } else {
                        CompatibleFinishReason::Other
                    },
                    text: choice.delta.content.clone(),
                    reasoning: None,
                    tool_calls: openai_chat_completions_compatible::tool_call_chunks(
                        &choice.delta.tool_calls,
                    ),
                    details: Vec::new(),
                },
            ),
        ))
    }

    fn build_final_response(&self, usage: Self::Usage) -> Self::FinalResponse {
        StreamingCompletionResponse { usage }
    }

    fn uses_distinct_tool_call_eviction(&self) -> bool {
        true
    }

    fn emits_complete_single_chunk_tool_calls(&self) -> bool {
        true
    }
}

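/// Send the streaming request and decode the SSE body with
/// [`LlamafileCompatibleProfile`], instrumented with the provided span.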
async fn send_streaming_request<T>(
    client: T,
    req: http::Request<Vec<u8>>,
    span: tracing::Span,
) -> Result<
    crate::streaming::StreamingCompletionResponse<StreamingCompletionResponse>,
    CompletionError,
>
where
    T: HttpClientExt + Clone + 'static,
{
    tracing::Instrument::instrument(
        openai_chat_completions_compatible::send_compatible_streaming_request(
            client,
            req,
            LlamafileCompatibleProfile,
        ),
        span,
    )
    .await
}

// ================================================================
// Embedding Model
// ================================================================

/// Llamafile embedding model.
///
/// Llamafile supports the OpenAI-compatible `/v1/embeddings` endpoint.
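///
/// A minimal sketch (assumes a llamafile server with an embedding-capable model;
/// the dimension below is illustrative):
/// ```rust,ignore
/// use rig_core::embeddings::EmbeddingModel as _;
///
/// let client = llamafile::Client::from_url("http://localhost:8080")?;
/// let model = llamafile::EmbeddingModel::new(client, llamafile::LLAMA_CPP, 4096);
/// let embedding = model.embed_texts(vec!["Hello!".to_string()]).await?;
/// ```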
#[derive(Clone)]
pub struct EmbeddingModel<T = reqwest::Client> {
    client: Client<T>,
    /// The model identifier.
    pub model: String,
    ndims: usize,
}

impl<T> EmbeddingModel<T> {
    /// Create a new embedding model for the given client, model name, and dimensions.
    pub fn new(client: Client<T>, model: impl Into<String>, ndims: usize) -> Self {
        Self {
            client,
            model: model.into(),
            ndims,
        }
    }
}

impl<T> embeddings::EmbeddingModel for EmbeddingModel<T>
where
    T: HttpClientExt + Clone + std::fmt::Debug + Default + Send + 'static,
{
    const MAX_DOCUMENTS: usize = 1024;

    type Client = Client<T>;

    fn make(client: &Self::Client, model: impl Into<String>, ndims: Option<usize>) -> Self {
        Self::new(client.clone(), model, ndims.unwrap_or_default())
    }

    fn ndims(&self) -> usize {
        self.ndims
    }

    async fn embed_texts(
        &self,
        documents: impl IntoIterator<Item = String>,
    ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
        let documents = documents.into_iter().collect::<Vec<_>>();

        let body = serde_json::json!({
            "model": self.model,
            "input": documents,
        });

        let body = serde_json::to_vec(&body)?;

        let req = self
            .client
            .post("v1/embeddings")?
            .body(body)
            .map_err(|e| EmbeddingError::HttpError(e.into()))?;

        let response = self.client.send(req).await?;

        if response.status().is_success() {
            let body: Vec<u8> = response.into_body().await?;
            let body: ApiResponse<openai::EmbeddingResponse> = serde_json::from_slice(&body)?;

            match body {
                ApiResponse::Ok(response) => {
                    tracing::info!(target: "rig",
                        "Llamafile embedding token usage: {:?}",
                        response.usage
                    );

                    if response.data.len() != documents.len() {
                        return Err(EmbeddingError::ResponseError(
                            "Response data length does not match input length".into(),
                        ));
                    }

                    Ok(response
                        .data
                        .into_iter()
                        .zip(documents.into_iter())
                        .map(|(embedding, document)| embeddings::Embedding {
                            document,
                            vec: embedding
                                .embedding
                                .into_iter()
                                .filter_map(|n| n.as_f64())
                                .collect(),
                        })
                        .collect())
                }
                ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
            }
        } else {
            let text = http_client::text(response).await?;
            Err(EmbeddingError::ProviderError(text))
        }
    }
}

// ================================================================
// Tests
// ================================================================
#[cfg(test)]
mod tests {
    use super::*;
    use crate::client::Nothing;
    use crate::completion::Document;
    use std::collections::HashMap;

    #[test]
    fn test_client_initialization() {
        let _client =
            crate::providers::llamafile::Client::new(Nothing).expect("Client::new() failed");
        let _client_from_builder = crate::providers::llamafile::Client::builder()
            .api_key(Nothing)
            .build()
            .expect("Client::builder() failed");
    }

    #[test]
    fn test_client_from_url() {
        let _client = crate::providers::llamafile::Client::from_url("http://localhost:8080");
    }

    #[test]
    fn test_completion_request_conversion() {
        use crate::OneOrMany;
        use crate::completion::Message as CompletionMessage;
        use crate::message::{Text, UserContent};

        let completion_request = CompletionRequest {
            model: None,
            preamble: Some("You are a helpful assistant.".to_string()),
            chat_history: OneOrMany::one(CompletionMessage::User {
                content: OneOrMany::one(UserContent::Text(Text {
                    text: "Hello!".to_string(),
                })),
            }),
            documents: vec![],
            tools: vec![],
            temperature: Some(0.7),
            max_tokens: Some(256),
            tool_choice: None,
            additional_params: None,
            output_schema: None,
        };

        let request = LlamafileCompletionRequest::try_from((LLAMA_CPP, completion_request))
            .expect("Failed to create request");

        assert_eq!(request.model, LLAMA_CPP);
        assert_eq!(request.messages.len(), 2); // system + user
        assert_eq!(
            request.messages[0]["content"],
            "You are a helpful assistant."
        );
        assert_eq!(request.messages[1]["content"], "Hello!");
        assert_eq!(request.temperature, Some(0.7));
        assert_eq!(request.max_tokens, Some(256));
    }

    #[test]
    fn test_completion_request_flattens_text_only_document_arrays() {
        use crate::OneOrMany;
        use crate::completion::Message as CompletionMessage;
        use crate::message::{Text, UserContent};

        let completion_request = CompletionRequest {
            model: None,
            preamble: None,
            chat_history: OneOrMany::one(CompletionMessage::User {
                content: OneOrMany::one(UserContent::Text(Text {
                    text: "What does glarb-glarb mean?".to_string(),
                })),
            }),
            documents: vec![
                Document {
                    id: "doc-1".into(),
                    text: "Definition of flurbo: a green alien.".into(),
                    additional_props: HashMap::new(),
                },
                Document {
                    id: "doc-2".into(),
                    text: "Definition of glarb-glarb: an ancient farming tool.".into(),
                    additional_props: HashMap::new(),
                },
            ],
            tools: vec![],
            temperature: None,
            max_tokens: None,
            tool_choice: None,
            additional_params: None,
            output_schema: None,
        };

        let request = LlamafileCompletionRequest::try_from((LLAMA_CPP, completion_request))
            .expect("Failed to create request");

        assert_eq!(request.messages.len(), 2);
        assert!(request.messages[0]["content"].is_string());
        let documents = request.messages[0]["content"]
            .as_str()
            .expect("documents should serialize as a string");
        assert!(documents.contains("Definition of flurbo"));
        assert!(documents.contains("Definition of glarb-glarb"));
    }

    #[test]
    fn test_llamafile_message_value_flattens_assistant_text_content() {
        let message = openai::Message::Assistant {
            content: vec![openai::AssistantContent::Text {
                text: "Tool returned the answer.".into(),
            }],
            reasoning: None,
            refusal: None,
            audio: None,
            name: None,
            tool_calls: vec![openai::ToolCall {
                id: "call_1".into(),
                r#type: openai::ToolType::Function,
                function: openai::Function {
                    name: "weather".into(),
                    arguments: serde_json::json!({"city": "London"}),
                },
            }],
        };

        let value = llamafile_message_value(message).expect("message conversion should succeed");

        assert_eq!(value["role"], "assistant");
        assert_eq!(value["content"], "Tool returned the answer.");
        assert!(value["tool_calls"].is_array());
    }
}