rig/providers/huggingface/completion.rs

1use super::client::Client;
2use crate::completion::GetTokenUsage;
3use crate::http_client::HttpClientExt;
4use crate::providers::openai::StreamingCompletionResponse;
5use crate::telemetry::SpanCombinator;
6use crate::{
7    OneOrMany,
8    completion::{self, CompletionError, CompletionRequest},
9    json_utils,
10    message::{self},
11    one_or_many::string_or_one_or_many,
12};
13use serde::{Deserialize, Deserializer, Serialize, Serializer};
14use serde_json::Value;
15use std::{convert::Infallible, str::FromStr};
16use tracing::{Level, enabled, info_span};
17use tracing_futures::Instrument;
18
19#[derive(Debug, Deserialize)]
20#[serde(untagged)]
21pub enum ApiResponse<T> {
22    Ok(T),
23    Err(Value),
24}
25
26// ================================================================
27// Huggingface Completion API
28// ================================================================
29
// --- Conversational LLMs -------------------------------------------------

/// Model id `google/gemma-2-2b-it` (chat completion).
pub const GEMMA_2: &str = "google/gemma-2-2b-it";
/// Model id `meta-llama/Meta-Llama-3.1-8B-Instruct` (chat completion).
pub const META_LLAMA_3_1: &str = "meta-llama/Meta-Llama-3.1-8B-Instruct";
/// Model id `PowerInfer/SmallThinker-3B-Preview` (chat completion).
pub const SMALLTHINKER_PREVIEW: &str = "PowerInfer/SmallThinker-3B-Preview";
/// Model id `Qwen/Qwen2.5-7B-Instruct` (chat completion).
pub const QWEN2_5: &str = "Qwen/Qwen2.5-7B-Instruct";
/// Model id `Qwen/Qwen2.5-Coder-32B-Instruct` (chat completion).
pub const QWEN2_5_CODER: &str = "Qwen/Qwen2.5-Coder-32B-Instruct";

// --- Conversational VLMs (vision-language) -------------------------------

/// Model id `Qwen/Qwen2-VL-7B-Instruct` (vision-language chat completion).
pub const QWEN2_VL: &str = "Qwen/Qwen2-VL-7B-Instruct";
/// Model id `Qwen/QVQ-72B-Preview` (vision-language chat completion).
pub const QWEN_QVQ_PREVIEW: &str = "Qwen/QVQ-72B-Preview";
48
49#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
50pub struct Function {
51    name: String,
52    #[serde(
53        serialize_with = "json_utils::stringified_json::serialize",
54        deserialize_with = "deserialize_arguments"
55    )]
56    pub arguments: serde_json::Value,
57}
58
59fn deserialize_arguments<'de, D>(deserializer: D) -> Result<Value, D::Error>
60where
61    D: Deserializer<'de>,
62{
63    let value = Value::deserialize(deserializer)?;
64
65    match value {
66        Value::String(s) => serde_json::from_str(&s).map_err(serde::de::Error::custom),
67        other => Ok(other),
68    }
69}
70
71impl From<Function> for message::ToolFunction {
72    fn from(value: Function) -> Self {
73        message::ToolFunction {
74            name: value.name,
75            arguments: value.arguments,
76        }
77    }
78}
79
80#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
81#[serde(rename_all = "lowercase")]
82pub enum ToolType {
83    #[default]
84    Function,
85}
86
87#[derive(Debug, Deserialize, Serialize, Clone)]
88pub struct ToolDefinition {
89    pub r#type: String,
90    pub function: completion::ToolDefinition,
91}
92
93impl From<completion::ToolDefinition> for ToolDefinition {
94    fn from(tool: completion::ToolDefinition) -> Self {
95        Self {
96            r#type: "function".into(),
97            function: tool,
98        }
99    }
100}
101
102#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
103pub struct ToolCall {
104    pub id: String,
105    pub r#type: ToolType,
106    pub function: Function,
107}
108
109impl From<ToolCall> for message::ToolCall {
110    fn from(value: ToolCall) -> Self {
111        message::ToolCall {
112            id: value.id,
113            call_id: None,
114            function: value.function.into(),
115        }
116    }
117}
118
119impl From<message::ToolCall> for ToolCall {
120    fn from(value: message::ToolCall) -> Self {
121        ToolCall {
122            id: value.id,
123            r#type: ToolType::Function,
124            function: Function {
125                name: value.function.name,
126                arguments: value.function.arguments,
127            },
128        }
129    }
130}
131
132#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
133pub struct ImageUrl {
134    url: String,
135}
136
137#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
138#[serde(tag = "type", rename_all = "lowercase")]
139pub enum UserContent {
140    Text {
141        text: String,
142    },
143    #[serde(rename = "image_url")]
144    ImageUrl {
145        image_url: ImageUrl,
146    },
147}
148
149impl FromStr for UserContent {
150    type Err = Infallible;
151
152    fn from_str(s: &str) -> Result<Self, Self::Err> {
153        Ok(UserContent::Text {
154            text: s.to_string(),
155        })
156    }
157}
158
159#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
160#[serde(tag = "type", rename_all = "lowercase")]
161pub enum AssistantContent {
162    Text { text: String },
163}
164
165impl FromStr for AssistantContent {
166    type Err = Infallible;
167
168    fn from_str(s: &str) -> Result<Self, Self::Err> {
169        Ok(AssistantContent::Text {
170            text: s.to_string(),
171        })
172    }
173}
174
175#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
176#[serde(tag = "type", rename_all = "lowercase")]
177pub enum SystemContent {
178    Text { text: String },
179}
180
181impl FromStr for SystemContent {
182    type Err = Infallible;
183
184    fn from_str(s: &str) -> Result<Self, Self::Err> {
185        Ok(SystemContent::Text {
186            text: s.to_string(),
187        })
188    }
189}
190
191impl From<UserContent> for message::UserContent {
192    fn from(value: UserContent) -> Self {
193        match value {
194            UserContent::Text { text } => message::UserContent::text(text),
195            UserContent::ImageUrl { image_url } => {
196                message::UserContent::image_url(image_url.url, None, None)
197            }
198        }
199    }
200}
201
202impl TryFrom<message::UserContent> for UserContent {
203    type Error = message::MessageError;
204
205    fn try_from(content: message::UserContent) -> Result<Self, Self::Error> {
206        match content {
207            message::UserContent::Text(text) => Ok(UserContent::Text { text: text.text }),
208            message::UserContent::Document(message::Document {
209                data: message::DocumentSourceKind::Raw(raw),
210                ..
211            }) => {
212                let text = String::from_utf8_lossy(raw.as_slice()).into();
213                Ok(UserContent::Text { text })
214            }
215            message::UserContent::Document(message::Document {
216                data:
217                    message::DocumentSourceKind::Base64(text)
218                    | message::DocumentSourceKind::String(text),
219                ..
220            }) => Ok(UserContent::Text { text }),
221            message::UserContent::Image(message::Image { data, .. }) => match data {
222                message::DocumentSourceKind::Url(url) => Ok(UserContent::ImageUrl {
223                    image_url: ImageUrl { url },
224                }),
225                _ => Err(message::MessageError::ConversionError(
226                    "Huggingface only supports images as urls".into(),
227                )),
228            },
229            _ => Err(message::MessageError::ConversionError(
230                "Huggingface only supports text and images".into(),
231            )),
232        }
233    }
234}
235
236#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
237#[serde(tag = "role", rename_all = "lowercase")]
238pub enum Message {
239    System {
240        #[serde(deserialize_with = "string_or_one_or_many")]
241        content: OneOrMany<SystemContent>,
242    },
243    User {
244        #[serde(deserialize_with = "string_or_one_or_many")]
245        content: OneOrMany<UserContent>,
246    },
247    Assistant {
248        #[serde(default, deserialize_with = "json_utils::string_or_vec")]
249        content: Vec<AssistantContent>,
250        #[serde(default, deserialize_with = "json_utils::null_or_vec")]
251        tool_calls: Vec<ToolCall>,
252    },
253    #[serde(rename = "tool", alias = "Tool")]
254    ToolResult {
255        name: String,
256        #[serde(skip_serializing_if = "Option::is_none")]
257        arguments: Option<serde_json::Value>,
258        #[serde(
259            deserialize_with = "string_or_one_or_many",
260            serialize_with = "serialize_tool_content"
261        )]
262        content: OneOrMany<String>,
263    },
264}
265
266fn serialize_tool_content<S>(content: &OneOrMany<String>, serializer: S) -> Result<S::Ok, S::Error>
267where
268    S: Serializer,
269{
270    // OpenAI-compatible APIs expect tool content as a string, not an array
271    let joined = content
272        .iter()
273        .map(String::as_str)
274        .collect::<Vec<_>>()
275        .join("\n");
276    serializer.serialize_str(&joined)
277}
278
279impl Message {
280    pub fn system(content: &str) -> Self {
281        Message::System {
282            content: OneOrMany::one(SystemContent::Text {
283                text: content.to_string(),
284            }),
285        }
286    }
287}
288
289impl TryFrom<message::Message> for Vec<Message> {
290    type Error = message::MessageError;
291
292    fn try_from(message: message::Message) -> Result<Vec<Message>, Self::Error> {
293        match message {
294            message::Message::User { content } => {
295                let (tool_results, other_content): (Vec<_>, Vec<_>) = content
296                    .into_iter()
297                    .partition(|content| matches!(content, message::UserContent::ToolResult(_)));
298
299                if !tool_results.is_empty() {
300                    tool_results
301                        .into_iter()
302                        .map(|content| match content {
303                            message::UserContent::ToolResult(message::ToolResult {
304                                id,
305                                content,
306                                ..
307                            }) => Ok::<_, message::MessageError>(Message::ToolResult {
308                                name: id,
309                                arguments: None,
310                                content: content.try_map(|content| match content {
311                                    message::ToolResultContent::Text(message::Text { text }) => {
312                                        Ok(text)
313                                    }
314                                    _ => Err(message::MessageError::ConversionError(
315                                        "Tool result content does not support non-text".into(),
316                                    )),
317                                })?,
318                            }),
319                            _ => unreachable!(),
320                        })
321                        .collect::<Result<Vec<_>, _>>()
322                } else {
323                    let other_content = OneOrMany::many(other_content).expect(
324                        "There must be other content here if there were no tool result content",
325                    );
326
327                    Ok(vec![Message::User {
328                        content: other_content.try_map(|content| match content {
329                            message::UserContent::Text(text) => {
330                                Ok(UserContent::Text { text: text.text })
331                            }
332                            message::UserContent::Image(image) => {
333                                let url = image.try_into_url()?;
334
335                                Ok(UserContent::ImageUrl {
336                                    image_url: ImageUrl { url },
337                                })
338                            }
339                            message::UserContent::Document(message::Document {
340                                data: message::DocumentSourceKind::Raw(raw), ..
341                            }) => {
342                                let text = String::from_utf8_lossy(raw.as_slice()).into();
343                                Ok(UserContent::Text { text })
344                            }
345                            message::UserContent::Document(message::Document {
346                                data: message::DocumentSourceKind::Base64(text) | message::DocumentSourceKind::String(text), ..
347                            }) => {
348                                Ok(UserContent::Text { text })
349                            }
350                            _ => Err(message::MessageError::ConversionError(
351                                "Huggingface inputs only support text and image URLs (both base64-encoded images and regular URLs)".into(),
352                            )),
353                        })?,
354                    }])
355                }
356            }
357            message::Message::Assistant { content, .. } => {
358                let (text_content, tool_calls) = content.into_iter().fold(
359                    (Vec::new(), Vec::new()),
360                    |(mut texts, mut tools), content| {
361                        match content {
362                            message::AssistantContent::Text(text) => texts.push(text),
363                            message::AssistantContent::ToolCall(tool_call) => tools.push(tool_call),
364                            message::AssistantContent::Reasoning(_) => {
365                                unimplemented!("Reasoning is not supported on HuggingFace via Rig");
366                            }
367                            message::AssistantContent::Image(_) => {
368                                unimplemented!(
369                                    "Image content is not supported on HuggingFace via Rig"
370                                );
371                            }
372                        }
373                        (texts, tools)
374                    },
375                );
376
377                // `OneOrMany` ensures at least one `AssistantContent::Text` or `ToolCall` exists,
378                //  so either `content` or `tool_calls` will have some content.
379                Ok(vec![Message::Assistant {
380                    content: text_content
381                        .into_iter()
382                        .map(|content| AssistantContent::Text { text: content.text })
383                        .collect::<Vec<_>>(),
384                    tool_calls: tool_calls
385                        .into_iter()
386                        .map(|tool_call| tool_call.into())
387                        .collect::<Vec<_>>(),
388                }])
389            }
390        }
391    }
392}
393
394impl TryFrom<Message> for message::Message {
395    type Error = message::MessageError;
396
397    fn try_from(message: Message) -> Result<Self, Self::Error> {
398        Ok(match message {
399            Message::User { content, .. } => message::Message::User {
400                content: content.map(|content| content.into()),
401            },
402            Message::Assistant {
403                content,
404                tool_calls,
405                ..
406            } => {
407                let mut content = content
408                    .into_iter()
409                    .map(|content| match content {
410                        AssistantContent::Text { text } => message::AssistantContent::text(text),
411                    })
412                    .collect::<Vec<_>>();
413
414                content.extend(
415                    tool_calls
416                        .into_iter()
417                        .map(|tool_call| Ok(message::AssistantContent::ToolCall(tool_call.into())))
418                        .collect::<Result<Vec<_>, _>>()?,
419                );
420
421                message::Message::Assistant {
422                    id: None,
423                    content: OneOrMany::many(content).map_err(|_| {
424                        message::MessageError::ConversionError(
425                            "Neither `content` nor `tool_calls` was provided to the Message"
426                                .to_owned(),
427                        )
428                    })?,
429                }
430            }
431
432            Message::ToolResult { name, content, .. } => message::Message::User {
433                content: OneOrMany::one(message::UserContent::tool_result(
434                    name,
435                    content.map(message::ToolResultContent::text),
436                )),
437            },
438
439            // System messages should get stripped out when converting message's, this is just a
440            // stop gap to avoid obnoxious error handling or panic occurring.
441            Message::System { content, .. } => message::Message::User {
442                content: content.map(|c| match c {
443                    SystemContent::Text { text } => message::UserContent::text(text),
444                }),
445            },
446        })
447    }
448}
449
450#[derive(Clone, Debug, Deserialize, Serialize)]
451pub struct Choice {
452    pub finish_reason: String,
453    pub index: usize,
454    #[serde(default)]
455    pub logprobs: serde_json::Value,
456    pub message: Message,
457}
458
459#[derive(Debug, Deserialize, Clone, Serialize)]
460pub struct Usage {
461    pub completion_tokens: i32,
462    pub prompt_tokens: i32,
463    pub total_tokens: i32,
464}
465
466impl GetTokenUsage for Usage {
467    fn token_usage(&self) -> Option<crate::completion::Usage> {
468        let mut usage = crate::completion::Usage::new();
469        usage.input_tokens = self.prompt_tokens as u64;
470        usage.output_tokens = self.completion_tokens as u64;
471        usage.total_tokens = self.total_tokens as u64;
472
473        Some(usage)
474    }
475}
476
477#[derive(Clone, Debug, Deserialize, Serialize)]
478pub struct CompletionResponse {
479    pub created: i32,
480    pub id: String,
481    pub model: String,
482    pub choices: Vec<Choice>,
483    #[serde(default, deserialize_with = "default_string_on_null")]
484    pub system_fingerprint: String,
485    pub usage: Usage,
486}
487
488impl crate::telemetry::ProviderResponseExt for CompletionResponse {
489    type OutputMessage = Choice;
490    type Usage = Usage;
491
492    fn get_response_id(&self) -> Option<String> {
493        Some(self.id.clone())
494    }
495
496    fn get_response_model_name(&self) -> Option<String> {
497        Some(self.model.clone())
498    }
499
500    fn get_output_messages(&self) -> Vec<Self::OutputMessage> {
501        self.choices.clone()
502    }
503
504    fn get_text_response(&self) -> Option<String> {
505        let text_response = self
506            .choices
507            .iter()
508            .filter_map(|x| {
509                let Message::User { ref content } = x.message else {
510                    return None;
511                };
512
513                let text = content
514                    .iter()
515                    .filter_map(|x| {
516                        if let UserContent::Text { text } = x {
517                            Some(text.clone())
518                        } else {
519                            None
520                        }
521                    })
522                    .collect::<Vec<String>>();
523
524                if text.is_empty() {
525                    None
526                } else {
527                    Some(text.join("\n"))
528                }
529            })
530            .collect::<Vec<String>>()
531            .join("\n");
532
533        if text_response.is_empty() {
534            None
535        } else {
536            Some(text_response)
537        }
538    }
539
540    fn get_usage(&self) -> Option<Self::Usage> {
541        Some(self.usage.clone())
542    }
543}
544
545fn default_string_on_null<'de, D>(deserializer: D) -> Result<String, D::Error>
546where
547    D: Deserializer<'de>,
548{
549    match Option::<String>::deserialize(deserializer)? {
550        Some(value) => Ok(value),      // Use provided value
551        None => Ok(String::default()), // Use `Default` implementation
552    }
553}
554
555impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
556    type Error = CompletionError;
557
558    fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
559        let choice = response.choices.first().ok_or_else(|| {
560            CompletionError::ResponseError("Response contained no choices".to_owned())
561        })?;
562
563        let content = match &choice.message {
564            Message::Assistant {
565                content,
566                tool_calls,
567                ..
568            } => {
569                let mut content = content
570                    .iter()
571                    .map(|c| match c {
572                        AssistantContent::Text { text } => message::AssistantContent::text(text),
573                    })
574                    .collect::<Vec<_>>();
575
576                content.extend(
577                    tool_calls
578                        .iter()
579                        .map(|call| {
580                            completion::AssistantContent::tool_call(
581                                &call.id,
582                                &call.function.name,
583                                call.function.arguments.clone(),
584                            )
585                        })
586                        .collect::<Vec<_>>(),
587                );
588                Ok(content)
589            }
590            _ => Err(CompletionError::ResponseError(
591                "Response did not contain a valid message or tool call".into(),
592            )),
593        }?;
594
595        let choice = OneOrMany::many(content).map_err(|_| {
596            CompletionError::ResponseError(
597                "Response contained no message or tool call (empty)".to_owned(),
598            )
599        })?;
600
601        let usage = completion::Usage {
602            input_tokens: response.usage.prompt_tokens as u64,
603            output_tokens: response.usage.completion_tokens as u64,
604            total_tokens: response.usage.total_tokens as u64,
605        };
606
607        Ok(completion::CompletionResponse {
608            choice,
609            usage,
610            raw_response: response,
611        })
612    }
613}
614
615#[derive(Debug, Serialize, Deserialize)]
616pub(super) struct HuggingfaceCompletionRequest {
617    model: String,
618    pub messages: Vec<Message>,
619    #[serde(flatten, skip_serializing_if = "Option::is_none")]
620    temperature: Option<f64>,
621    #[serde(skip_serializing_if = "Vec::is_empty")]
622    tools: Vec<ToolDefinition>,
623    #[serde(flatten, skip_serializing_if = "Option::is_none")]
624    tool_choice: Option<crate::providers::openai::completion::ToolChoice>,
625    #[serde(flatten, skip_serializing_if = "Option::is_none")]
626    pub additional_params: Option<serde_json::Value>,
627}
628
629impl TryFrom<(&str, CompletionRequest)> for HuggingfaceCompletionRequest {
630    type Error = CompletionError;
631
632    fn try_from((model, req): (&str, CompletionRequest)) -> Result<Self, Self::Error> {
633        let mut full_history: Vec<Message> = match &req.preamble {
634            Some(preamble) => vec![Message::system(preamble)],
635            None => vec![],
636        };
637        if let Some(docs) = req.normalized_documents() {
638            let docs: Vec<Message> = docs.try_into()?;
639            full_history.extend(docs);
640        }
641
642        let chat_history: Vec<Message> = req
643            .chat_history
644            .clone()
645            .into_iter()
646            .map(|message| message.try_into())
647            .collect::<Result<Vec<Vec<Message>>, _>>()?
648            .into_iter()
649            .flatten()
650            .collect();
651
652        full_history.extend(chat_history);
653
654        let tool_choice = req
655            .tool_choice
656            .clone()
657            .map(crate::providers::openai::completion::ToolChoice::try_from)
658            .transpose()?;
659
660        Ok(Self {
661            model: model.to_string(),
662            messages: full_history,
663            temperature: req.temperature,
664            tools: req
665                .tools
666                .clone()
667                .into_iter()
668                .map(ToolDefinition::from)
669                .collect::<Vec<_>>(),
670            tool_choice,
671            additional_params: req.additional_params,
672        })
673    }
674}
675
676#[derive(Clone)]
677pub struct CompletionModel<T = reqwest::Client> {
678    pub(crate) client: Client<T>,
679    /// Name of the model (e.g: google/gemma-2-2b-it)
680    pub model: String,
681}
682
683impl<T> CompletionModel<T> {
684    pub fn new(client: Client<T>, model: &str) -> Self {
685        Self {
686            client,
687            model: model.to_string(),
688        }
689    }
690}
691
692impl<T> completion::CompletionModel for CompletionModel<T>
693where
694    T: HttpClientExt + Clone + 'static,
695{
696    type Response = CompletionResponse;
697    type StreamingResponse = StreamingCompletionResponse;
698
699    type Client = Client<T>;
700
701    fn make(client: &Self::Client, model: impl Into<String>) -> Self {
702        Self::new(client.clone(), &model.into())
703    }
704
705    #[cfg_attr(feature = "worker", worker::send)]
706    async fn completion(
707        &self,
708        completion_request: CompletionRequest,
709    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
710        let span = if tracing::Span::current().is_disabled() {
711            info_span!(
712                target: "rig::completions",
713                "chat",
714                gen_ai.operation.name = "chat",
715                gen_ai.provider.name = "huggingface",
716                gen_ai.request.model = self.model,
717                gen_ai.system_instructions = &completion_request.preamble,
718                gen_ai.response.id = tracing::field::Empty,
719                gen_ai.response.model = tracing::field::Empty,
720                gen_ai.usage.output_tokens = tracing::field::Empty,
721                gen_ai.usage.input_tokens = tracing::field::Empty,
722            )
723        } else {
724            tracing::Span::current()
725        };
726
727        let model = self.client.subprovider().model_identifier(&self.model);
728        let request = HuggingfaceCompletionRequest::try_from((model.as_ref(), completion_request))?;
729
730        if enabled!(Level::TRACE) {
731            tracing::trace!(
732                target: "rig::completions",
733                "Huggingface completion request: {}",
734                serde_json::to_string_pretty(&request)?
735            );
736        }
737
738        let request = serde_json::to_vec(&request)?;
739
740        let path = self.client.subprovider().completion_endpoint(&self.model);
741        let request = self
742            .client
743            .post(&path)?
744            .header("Content-Type", "application/json")
745            .body(request)
746            .map_err(|e| CompletionError::HttpError(e.into()))?;
747
748        async move {
749            let response = self.client.send(request).await?;
750
751            if response.status().is_success() {
752                let bytes: Vec<u8> = response.into_body().await?;
753                let text = String::from_utf8_lossy(&bytes);
754
755                tracing::debug!(target: "rig", "Huggingface completion error: {}", text);
756
757                match serde_json::from_slice::<ApiResponse<CompletionResponse>>(&bytes)? {
758                    ApiResponse::Ok(response) => {
759                        if enabled!(Level::TRACE) {
760                            tracing::trace!(
761                                target: "rig::completions",
762                                "Huggingface completion response: {}",
763                                serde_json::to_string_pretty(&response)?
764                            );
765                        }
766
767                        let span = tracing::Span::current();
768                        span.record_token_usage(&response.usage);
769                        span.record_response_metadata(&response);
770
771                        response.try_into()
772                    }
773                    ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.to_string())),
774                }
775            } else {
776                let status = response.status();
777                let text: Vec<u8> = response.into_body().await?;
778                let text: String = String::from_utf8_lossy(&text).into();
779
780                Err(CompletionError::ProviderError(format!(
781                    "{}: {}",
782                    status, text
783                )))
784            }
785        }
786        .instrument(span)
787        .await
788    }
789
790    #[cfg_attr(feature = "worker", worker::send)]
791    async fn stream(
792        &self,
793        request: CompletionRequest,
794    ) -> Result<
795        crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
796        CompletionError,
797    > {
798        CompletionModel::stream(self, request).await
799    }
800}
801
802#[cfg(test)]
803mod tests {
804    use super::*;
805    use serde_path_to_error::deserialize;
806
807    #[test]
808    fn test_deserialize_message() {
809        let assistant_message_json = r#"
810        {
811            "role": "assistant",
812            "content": "\n\nHello there, how may I assist you today?"
813        }
814        "#;
815
816        let assistant_message_json2 = r#"
817        {
818            "role": "assistant",
819            "content": [
820                {
821                    "type": "text",
822                    "text": "\n\nHello there, how may I assist you today?"
823                }
824            ],
825            "tool_calls": null
826        }
827        "#;
828
829        let assistant_message_json3 = r#"
830        {
831            "role": "assistant",
832            "tool_calls": [
833                {
834                    "id": "call_h89ipqYUjEpCPI6SxspMnoUU",
835                    "type": "function",
836                    "function": {
837                        "name": "subtract",
838                        "arguments": {"x": 2, "y": 5}
839                    }
840                }
841            ],
842            "content": null,
843            "refusal": null
844        }
845        "#;
846
847        let user_message_json = r#"
848        {
849            "role": "user",
850            "content": [
851                {
852                    "type": "text",
853                    "text": "What's in this image?"
854                },
855                {
856                    "type": "image_url",
857                    "image_url": {
858                        "url": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
859                    }
860                }
861            ]
862        }
863        "#;
864
865        let assistant_message: Message = {
866            let jd = &mut serde_json::Deserializer::from_str(assistant_message_json);
867            deserialize(jd).unwrap_or_else(|err| {
868                panic!(
869                    "Deserialization error at {} ({}:{}): {}",
870                    err.path(),
871                    err.inner().line(),
872                    err.inner().column(),
873                    err
874                );
875            })
876        };
877
878        let assistant_message2: Message = {
879            let jd = &mut serde_json::Deserializer::from_str(assistant_message_json2);
880            deserialize(jd).unwrap_or_else(|err| {
881                panic!(
882                    "Deserialization error at {} ({}:{}): {}",
883                    err.path(),
884                    err.inner().line(),
885                    err.inner().column(),
886                    err
887                );
888            })
889        };
890
891        let assistant_message3: Message = {
892            let jd: &mut serde_json::Deserializer<serde_json::de::StrRead<'_>> =
893                &mut serde_json::Deserializer::from_str(assistant_message_json3);
894            deserialize(jd).unwrap_or_else(|err| {
895                panic!(
896                    "Deserialization error at {} ({}:{}): {}",
897                    err.path(),
898                    err.inner().line(),
899                    err.inner().column(),
900                    err
901                );
902            })
903        };
904
905        let user_message: Message = {
906            let jd = &mut serde_json::Deserializer::from_str(user_message_json);
907            deserialize(jd).unwrap_or_else(|err| {
908                panic!(
909                    "Deserialization error at {} ({}:{}): {}",
910                    err.path(),
911                    err.inner().line(),
912                    err.inner().column(),
913                    err
914                );
915            })
916        };
917
918        match assistant_message {
919            Message::Assistant { content, .. } => {
920                assert_eq!(
921                    content[0],
922                    AssistantContent::Text {
923                        text: "\n\nHello there, how may I assist you today?".to_string()
924                    }
925                );
926            }
927            _ => panic!("Expected assistant message"),
928        }
929
930        match assistant_message2 {
931            Message::Assistant {
932                content,
933                tool_calls,
934                ..
935            } => {
936                assert_eq!(
937                    content[0],
938                    AssistantContent::Text {
939                        text: "\n\nHello there, how may I assist you today?".to_string()
940                    }
941                );
942
943                assert_eq!(tool_calls, vec![]);
944            }
945            _ => panic!("Expected assistant message"),
946        }
947
948        match assistant_message3 {
949            Message::Assistant {
950                content,
951                tool_calls,
952                ..
953            } => {
954                assert!(content.is_empty());
955                assert_eq!(
956                    tool_calls[0],
957                    ToolCall {
958                        id: "call_h89ipqYUjEpCPI6SxspMnoUU".to_string(),
959                        r#type: ToolType::Function,
960                        function: Function {
961                            name: "subtract".to_string(),
962                            arguments: serde_json::json!({"x": 2, "y": 5}),
963                        },
964                    }
965                );
966            }
967            _ => panic!("Expected assistant message"),
968        }
969
970        match user_message {
971            Message::User { content, .. } => {
972                let (first, second) = {
973                    let mut iter = content.into_iter();
974                    (iter.next().unwrap(), iter.next().unwrap())
975                };
976                assert_eq!(
977                    first,
978                    UserContent::Text {
979                        text: "What's in this image?".to_string()
980                    }
981                );
982                assert_eq!(second, UserContent::ImageUrl { image_url: ImageUrl { url: "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg".to_string() } });
983            }
984            _ => panic!("Expected user message"),
985        }
986    }
987
988    #[test]
989    fn test_message_to_message_conversion() {
990        let user_message = message::Message::User {
991            content: OneOrMany::one(message::UserContent::text("Hello")),
992        };
993
994        let assistant_message = message::Message::Assistant {
995            id: None,
996            content: OneOrMany::one(message::AssistantContent::text("Hi there!")),
997        };
998
999        let converted_user_message: Vec<Message> = user_message.clone().try_into().unwrap();
1000        let converted_assistant_message: Vec<Message> =
1001            assistant_message.clone().try_into().unwrap();
1002
1003        match converted_user_message[0].clone() {
1004            Message::User { content, .. } => {
1005                assert_eq!(
1006                    content.first(),
1007                    UserContent::Text {
1008                        text: "Hello".to_string()
1009                    }
1010                );
1011            }
1012            _ => panic!("Expected user message"),
1013        }
1014
1015        match converted_assistant_message[0].clone() {
1016            Message::Assistant { content, .. } => {
1017                assert_eq!(
1018                    content[0],
1019                    AssistantContent::Text {
1020                        text: "Hi there!".to_string()
1021                    }
1022                );
1023            }
1024            _ => panic!("Expected assistant message"),
1025        }
1026
1027        let original_user_message: message::Message =
1028            converted_user_message[0].clone().try_into().unwrap();
1029        let original_assistant_message: message::Message =
1030            converted_assistant_message[0].clone().try_into().unwrap();
1031
1032        assert_eq!(original_user_message, user_message);
1033        assert_eq!(original_assistant_message, assistant_message);
1034    }
1035
1036    #[test]
1037    fn test_message_from_message_conversion() {
1038        let user_message = Message::User {
1039            content: OneOrMany::one(UserContent::Text {
1040                text: "Hello".to_string(),
1041            }),
1042        };
1043
1044        let assistant_message = Message::Assistant {
1045            content: vec![AssistantContent::Text {
1046                text: "Hi there!".to_string(),
1047            }],
1048            tool_calls: vec![],
1049        };
1050
1051        let converted_user_message: message::Message = user_message.clone().try_into().unwrap();
1052        let converted_assistant_message: message::Message =
1053            assistant_message.clone().try_into().unwrap();
1054
1055        match converted_user_message.clone() {
1056            message::Message::User { content } => {
1057                assert_eq!(content.first(), message::UserContent::text("Hello"));
1058            }
1059            _ => panic!("Expected user message"),
1060        }
1061
1062        match converted_assistant_message.clone() {
1063            message::Message::Assistant { content, .. } => {
1064                assert_eq!(
1065                    content.first(),
1066                    message::AssistantContent::text("Hi there!")
1067                );
1068            }
1069            _ => panic!("Expected assistant message"),
1070        }
1071
1072        let original_user_message: Vec<Message> = converted_user_message.try_into().unwrap();
1073        let original_assistant_message: Vec<Message> =
1074            converted_assistant_message.try_into().unwrap();
1075
1076        assert_eq!(original_user_message[0], user_message);
1077        assert_eq!(original_assistant_message[0], assistant_message);
1078    }
1079
1080    #[test]
1081    fn test_responses() {
1082        let fireworks_response_json = r#"
1083        {
1084            "choices": [
1085                {
1086                    "finish_reason": "tool_calls",
1087                    "index": 0,
1088                    "message": {
1089                        "role": "assistant",
1090                        "tool_calls": [
1091                            {
1092                                "function": {
1093                                "arguments": "{\"x\": 2, \"y\": 5}",
1094                                "name": "subtract"
1095                                },
1096                                "id": "call_1BspL6mQqjKgvsQbH1TIYkHf",
1097                                "index": 0,
1098                                "type": "function"
1099                            }
1100                        ]
1101                    }
1102                }
1103            ],
1104            "created": 1740704000,
1105            "id": "2a81f6a1-4866-42fb-9902-2655a2b5b1ff",
1106            "model": "accounts/fireworks/models/deepseek-v3",
1107            "object": "chat.completion",
1108            "usage": {
1109                "completion_tokens": 26,
1110                "prompt_tokens": 248,
1111                "total_tokens": 274
1112            }
1113        }
1114        "#;
1115
1116        let novita_response_json = r#"
1117        {
1118            "choices": [
1119                {
1120                    "finish_reason": "tool_calls",
1121                    "index": 0,
1122                    "logprobs": null,
1123                    "message": {
1124                        "audio": null,
1125                        "content": null,
1126                        "function_call": null,
1127                        "reasoning_content": null,
1128                        "refusal": null,
1129                        "role": "assistant",
1130                        "tool_calls": [
1131                            {
1132                                "function": {
1133                                    "arguments": "{\"x\": \"2\", \"y\": \"5\"}",
1134                                    "name": "subtract"
1135                                },
1136                                "id": "chatcmpl-tool-f6d2af7c8dc041058f95e2c2eede45c5",
1137                                "type": "function"
1138                            }
1139                        ]
1140                    },
1141                    "stop_reason": 128008
1142                }
1143            ],
1144            "created": 1740704592,
1145            "id": "chatcmpl-a92c60ae125c47c998ecdcb53387fed4",
1146            "model": "meta-llama/Meta-Llama-3.1-8B-Instruct-fast",
1147            "object": "chat.completion",
1148            "prompt_logprobs": null,
1149            "service_tier": null,
1150            "system_fingerprint": null,
1151            "usage": {
1152                "completion_tokens": 28,
1153                "completion_tokens_details": null,
1154                "prompt_tokens": 335,
1155                "prompt_tokens_details": null,
1156                "total_tokens": 363
1157            }
1158        }
1159        "#;
1160
1161        let _firework_response: CompletionResponse = {
1162            let jd = &mut serde_json::Deserializer::from_str(fireworks_response_json);
1163            deserialize(jd).unwrap_or_else(|err| {
1164                panic!(
1165                    "Deserialization error at {} ({}:{}): {}",
1166                    err.path(),
1167                    err.inner().line(),
1168                    err.inner().column(),
1169                    err
1170                );
1171            })
1172        };
1173
1174        let _novita_response: CompletionResponse = {
1175            let jd = &mut serde_json::Deserializer::from_str(novita_response_json);
1176            deserialize(jd).unwrap_or_else(|err| {
1177                panic!(
1178                    "Deserialization error at {} ({}:{}): {}",
1179                    err.path(),
1180                    err.inner().line(),
1181                    err.inner().column(),
1182                    err
1183                );
1184            })
1185        };
1186    }
1187}