// rig/providers/huggingface/streaming.rs

use super::completion::CompletionModel;
use crate::completion::{CompletionError, CompletionRequest};
use crate::json_utils;
use crate::json_utils::merge_inplace;
use crate::providers::openai::send_compatible_streaming_request;
use crate::streaming::{StreamingCompletionModel, StreamingResult};
use serde::{Deserialize, Serialize};
use serde_json::{json, Value};
use std::convert::Infallible;
use std::str::FromStr;
11
12#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
13#[serde(rename_all = "lowercase", tag = "type")]
14/// Represents the content sent back in the StreamDelta for an Assistant
15enum AssistantContent {
16    Text { text: String },
17}
18
19// Ensure that string contents can be serialized correctly
20impl FromStr for AssistantContent {
21    type Err = Infallible;
22
23    fn from_str(s: &str) -> Result<Self, Self::Err> {
24        Ok(AssistantContent::Text {
25            text: s.to_string(),
26        })
27    }
28}
29
30#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
31#[serde(rename_all = "lowercase", tag = "role")]
32enum StreamDelta {
33    Assistant {
34        #[serde(deserialize_with = "json_utils::string_or_vec")]
35        content: Vec<AssistantContent>,
36    },
37}
38
39#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
40struct StreamingChoice {
41    index: usize,
42    delta: StreamDelta,
43    logprobs: Value,
44    finish_reason: Option<String>,
45}
46
47#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
48struct CompletionChunk {
49    id: String,
50    created: i32,
51    model: String,
52    #[serde(default)]
53    system_fingerprint: String,
54    choices: Vec<StreamingChoice>,
55}
56
57impl StreamingCompletionModel for CompletionModel {
58    async fn stream(
59        &self,
60        completion_request: CompletionRequest,
61    ) -> Result<StreamingResult, CompletionError> {
62        let mut request = self.create_request_body(&completion_request)?;
63
64        // Enable streaming
65        merge_inplace(&mut request, json!({"stream": true}));
66
67        if let Some(ref params) = completion_request.additional_params {
68            merge_inplace(&mut request, params.clone());
69        }
70
71        // HF Inference API uses the model in the path even though its specified in the request body
72        let path = self.client.sub_provider.completion_endpoint(&self.model);
73
74        let builder = self.client.post(&path).json(&request);
75
76        send_compatible_streaming_request(builder).await
77    }
78}