infernum_core/
streaming.rs1use std::pin::Pin;
4use std::task::{Context, Poll};
5
6use futures::Stream;
7use serde::{Deserialize, Serialize};
8
9use crate::error::Result;
10use crate::response::TokenInfo;
11use crate::types::{FinishReason, ModelId, RequestId, Usage};
12
13#[derive(Debug, Clone, Serialize, Deserialize)]
15pub struct StreamChunk {
16 pub request_id: RequestId,
18
19 pub model: ModelId,
21
22 pub choices: Vec<StreamChoice>,
24
25 #[serde(skip_serializing_if = "Option::is_none")]
27 pub usage: Option<Usage>,
28}
29
30#[derive(Debug, Clone, Serialize, Deserialize)]
32pub struct StreamChoice {
33 pub index: u32,
35
36 pub delta: StreamDelta,
38
39 #[serde(skip_serializing_if = "Option::is_none")]
41 pub finish_reason: Option<FinishReason>,
42}
43
44#[derive(Debug, Clone, Serialize, Deserialize)]
46pub struct StreamDelta {
47 #[serde(skip_serializing_if = "Option::is_none")]
49 pub content: Option<String>,
50
51 #[serde(skip_serializing_if = "Option::is_none")]
53 pub token: Option<TokenInfo>,
54}
55
56impl StreamDelta {
57 #[must_use]
59 pub fn text(content: impl Into<String>) -> Self {
60 Self {
61 content: Some(content.into()),
62 token: None,
63 }
64 }
65
66 #[must_use]
68 pub fn token(token: TokenInfo) -> Self {
69 Self {
70 content: Some(token.text.clone()),
71 token: Some(token),
72 }
73 }
74
75 #[must_use]
77 pub fn empty() -> Self {
78 Self {
79 content: None,
80 token: None,
81 }
82 }
83}
84
85pub struct TokenStream {
87 inner: Pin<Box<dyn Stream<Item = Result<StreamChunk>> + Send>>,
88}
89
90impl TokenStream {
91 pub fn new<S>(stream: S) -> Self
93 where
94 S: Stream<Item = Result<StreamChunk>> + Send + 'static,
95 {
96 Self {
97 inner: Box::pin(stream),
98 }
99 }
100
101 #[must_use]
103 pub fn empty() -> Self {
104 Self::new(futures::stream::empty())
105 }
106
107 #[must_use]
109 pub fn once(chunk: StreamChunk) -> Self {
110 Self::new(futures::stream::once(async move { Ok(chunk) }))
111 }
112
113 pub async fn collect(self) -> Result<Vec<StreamChunk>> {
119 use futures::StreamExt;
120 let mut chunks = Vec::new();
121 let mut stream = self;
122 while let Some(result) = stream.next().await {
123 chunks.push(result?);
124 }
125 Ok(chunks)
126 }
127
128 pub async fn collect_text(self) -> Result<String> {
134 let chunks = self.collect().await?;
135 let mut text = String::new();
136 for chunk in chunks {
137 for choice in chunk.choices {
138 if let Some(content) = choice.delta.content {
139 text.push_str(&content);
140 }
141 }
142 }
143 Ok(text)
144 }
145}
146
147impl Stream for TokenStream {
148 type Item = Result<StreamChunk>;
149
150 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
151 self.inner.as_mut().poll_next(cx)
152 }
153}
154
155#[derive(Debug)]
157pub struct StreamChunkBuilder {
158 request_id: RequestId,
159 model: ModelId,
160 choices: Vec<StreamChoice>,
161 usage: Option<Usage>,
162}
163
164impl StreamChunkBuilder {
165 #[must_use]
167 pub fn new(request_id: RequestId, model: ModelId) -> Self {
168 Self {
169 request_id,
170 model,
171 choices: Vec::new(),
172 usage: None,
173 }
174 }
175
176 #[must_use]
178 pub fn text(mut self, index: u32, content: impl Into<String>) -> Self {
179 self.choices.push(StreamChoice {
180 index,
181 delta: StreamDelta::text(content),
182 finish_reason: None,
183 });
184 self
185 }
186
187 #[must_use]
189 pub fn finish(mut self, index: u32, reason: FinishReason) -> Self {
190 self.choices.push(StreamChoice {
191 index,
192 delta: StreamDelta::empty(),
193 finish_reason: Some(reason),
194 });
195 self
196 }
197
198 #[must_use]
200 pub fn usage(mut self, usage: Usage) -> Self {
201 self.usage = Some(usage);
202 self
203 }
204
205 #[must_use]
207 pub fn build(self) -> StreamChunk {
208 StreamChunk {
209 request_id: self.request_id,
210 model: self.model,
211 choices: self.choices,
212 usage: self.usage,
213 }
214 }
215}