// rig/providers/huggingface/streaming.rs
use super::completion::CompletionModel;
use crate::completion::{CompletionError, CompletionRequest};
use crate::json_utils;
use crate::json_utils::merge_inplace;
use crate::providers::openai::send_compatible_streaming_request;
use crate::streaming::{StreamingCompletionModel, StreamingResult};
use serde::{Deserialize, Serialize};
use serde_json::{json, Value};
use std::convert::Infallible;
use std::str::FromStr;

12#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
13#[serde(rename_all = "lowercase", tag = "type")]
14enum AssistantContent {
16 Text { text: String },
17}
18
19impl FromStr for AssistantContent {
21 type Err = Infallible;
22
23 fn from_str(s: &str) -> Result<Self, Self::Err> {
24 Ok(AssistantContent::Text {
25 text: s.to_string(),
26 })
27 }
28}
29
30#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
31#[serde(rename_all = "lowercase", tag = "role")]
32enum StreamDelta {
33 Assistant {
34 #[serde(deserialize_with = "json_utils::string_or_vec")]
35 content: Vec<AssistantContent>,
36 },
37}
38
39#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
40struct StreamingChoice {
41 index: usize,
42 delta: StreamDelta,
43 logprobs: Value,
44 finish_reason: Option<String>,
45}
46
47#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
48struct CompletionChunk {
49 id: String,
50 created: i32,
51 model: String,
52 #[serde(default)]
53 system_fingerprint: String,
54 choices: Vec<StreamingChoice>,
55}
56
57impl StreamingCompletionModel for CompletionModel {
58 async fn stream(
59 &self,
60 completion_request: CompletionRequest,
61 ) -> Result<StreamingResult, CompletionError> {
62 let mut request = self.create_request_body(&completion_request)?;
63
64 merge_inplace(&mut request, json!({"stream": true}));
66
67 if let Some(ref params) = completion_request.additional_params {
68 merge_inplace(&mut request, params.clone());
69 }
70
71 let path = self.client.sub_provider.completion_endpoint(&self.model);
73
74 let builder = self.client.post(&path).json(&request);
75
76 send_compatible_streaming_request(builder).await
77 }
78}