1use async_stream::stream;
2use futures::StreamExt;
3use serde::{Deserialize, Serialize};
4use serde_json::json;
5
6use super::completion::{CompletionModel, Content, Message, ToolChoice, ToolDefinition, Usage};
7use super::decoders::sse::from_response as sse_from_response;
8use crate::completion::{CompletionError, CompletionRequest, GetTokenUsage};
9use crate::json_utils::merge_inplace;
10use crate::streaming;
11use crate::streaming::{RawStreamingChoice, StreamingResult};
12
/// One event decoded from the JSON payload of a server-sent event.
///
/// The variant is chosen by the payload's `type` field, matched against the
/// snake_case form of the variant name (`message_start`, `content_block_delta`, ...).
#[derive(Debug, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum StreamingEvent {
    /// Carries the initial message envelope, including its usage counters.
    MessageStart {
        message: MessageStart,
    },
    /// Opens the content block at `index`.
    ContentBlockStart {
        index: usize,
        content_block: Content,
    },
    /// Incremental update for the content block at `index`.
    ContentBlockDelta {
        index: usize,
        delta: ContentDelta,
    },
    /// Closes the content block at `index`.
    ContentBlockStop {
        index: usize,
    },
    /// Message-level update: stop information plus usage counters.
    MessageDelta {
        delta: MessageDelta,
        usage: PartialUsage,
    },
    /// Payload with `type` equal to `message_stop`; carries no data.
    MessageStop,
    /// Payload with `type` equal to `ping`; carries no data.
    Ping,
    /// Fallback for any `type` value not listed above, so unrecognised
    /// events deserialize successfully instead of failing.
    #[serde(other)]
    Unknown,
}
39
/// Message envelope carried by [`StreamingEvent::MessageStart`].
///
/// Within this file only `usage.input_tokens` is read (by `stream`); the
/// remaining fields are deserialized and exposed as-is.
#[derive(Debug, Deserialize)]
pub struct MessageStart {
    pub id: String,
    pub role: String,
    pub content: Vec<Content>,
    pub model: String,
    /// `None` when the field is absent or null in the payload.
    pub stop_reason: Option<String>,
    /// `None` when the field is absent or null in the payload.
    pub stop_sequence: Option<String>,
    /// Usage counters reported with the message start.
    pub usage: Usage,
}
50
/// Incremental payload of a [`StreamingEvent::ContentBlockDelta`], selected by
/// the payload's `type` field (`text_delta` or `input_json_delta`).
///
/// NOTE(review): there is no `#[serde(other)]` fallback here, so a delta with
/// any other `type` fails to deserialize and makes the whole event fail to
/// parse — confirm whether that is intended.
#[derive(Debug, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ContentDelta {
    /// A fragment of message text.
    TextDelta { text: String },
    /// A fragment of the JSON text for a tool call's input; fragments are
    /// concatenated by `handle_event` and parsed once the block stops.
    InputJsonDelta { partial_json: String },
}
57
/// Message-level fields carried by [`StreamingEvent::MessageDelta`].
#[derive(Debug, Deserialize)]
pub struct MessageDelta {
    /// When this is `Some`, `stream` emits the final response for the stream.
    pub stop_reason: Option<String>,
    /// `None` when the field is absent or null in the payload.
    pub stop_sequence: Option<String>,
}
63
/// Token counters as they appear on a message delta and in the final
/// streaming response.
#[derive(Debug, Deserialize, Clone, Serialize)]
pub struct PartialUsage {
    /// Output token count; required when deserializing.
    pub output_tokens: usize,
    /// Input token count; `None` when the field is missing from the payload.
    #[serde(default)]
    pub input_tokens: Option<usize>,
}
70
/// Accumulator for a tool call that is still being streamed.
///
/// Created by `handle_event` when a tool-use content block starts, extended
/// with each `InputJsonDelta`, and consumed when the block stops.
#[derive(Default)]
struct ToolCallState {
    /// Tool name copied from the tool-use content block.
    name: String,
    /// Tool call id copied from the tool-use content block.
    id: String,
    /// Concatenation of the `partial_json` fragments received so far.
    input_json: String,
}
77
/// Provider-specific final response of a streamed completion; `stream` emits
/// it once a message delta carries a stop reason.
#[derive(Clone, Deserialize, Serialize)]
pub struct StreamingCompletionResponse {
    /// Token counters gathered over the stream.
    pub usage: PartialUsage,
}
82
83impl GetTokenUsage for StreamingCompletionResponse {
84 fn token_usage(&self) -> Option<crate::completion::Usage> {
85 let mut usage = crate::completion::Usage::new();
86 usage.input_tokens = self.usage.input_tokens.unwrap_or(0) as u64;
87 usage.output_tokens = self.usage.output_tokens as u64;
88 usage.total_tokens =
89 self.usage.input_tokens.unwrap_or(0) as u64 + self.usage.output_tokens as u64;
90
91 Some(usage)
92 }
93}
94
95impl CompletionModel {
96 pub(crate) async fn stream(
97 &self,
98 completion_request: CompletionRequest,
99 ) -> Result<streaming::StreamingCompletionResponse<StreamingCompletionResponse>, CompletionError>
100 {
101 let max_tokens = if let Some(tokens) = completion_request.max_tokens {
102 tokens
103 } else if let Some(tokens) = self.default_max_tokens {
104 tokens
105 } else {
106 return Err(CompletionError::RequestError(
107 "`max_tokens` must be set for Anthropic".into(),
108 ));
109 };
110
111 let mut full_history = vec![];
112 if let Some(docs) = completion_request.normalized_documents() {
113 full_history.push(docs);
114 }
115 full_history.extend(completion_request.chat_history);
116
117 let full_history = full_history
118 .into_iter()
119 .map(Message::try_from)
120 .collect::<Result<Vec<Message>, _>>()?;
121
122 let mut request = json!({
123 "model": self.model,
124 "messages": full_history,
125 "max_tokens": max_tokens,
126 "system": completion_request.preamble.unwrap_or("".to_string()),
127 "stream": true,
128 });
129
130 if let Some(temperature) = completion_request.temperature {
131 merge_inplace(&mut request, json!({ "temperature": temperature }));
132 }
133
134 if !completion_request.tools.is_empty() {
135 merge_inplace(
136 &mut request,
137 json!({
138 "tools": completion_request
139 .tools
140 .into_iter()
141 .map(|tool| ToolDefinition {
142 name: tool.name,
143 description: Some(tool.description),
144 input_schema: tool.parameters,
145 })
146 .collect::<Vec<_>>(),
147 "tool_choice": ToolChoice::Auto,
148 }),
149 );
150 }
151
152 if let Some(ref params) = completion_request.additional_params {
153 merge_inplace(&mut request, params.clone())
154 }
155
156 let response = self
157 .client
158 .post("/v1/messages")
159 .json(&request)
160 .send()
161 .await?;
162
163 if !response.status().is_success() {
164 return Err(CompletionError::ProviderError(response.text().await?));
165 }
166
167 let sse_stream = sse_from_response(response);
169
170 let stream: StreamingResult<StreamingCompletionResponse> = Box::pin(stream! {
171 let mut current_tool_call: Option<ToolCallState> = None;
172 let mut sse_stream = Box::pin(sse_stream);
173 let mut input_tokens = 0;
174
175 while let Some(sse_result) = sse_stream.next().await {
176 match sse_result {
177 Ok(sse) => {
178 match serde_json::from_str::<StreamingEvent>(&sse.data) {
180 Ok(event) => {
181 match &event {
182 StreamingEvent::MessageStart { message } => {
183 input_tokens = message.usage.input_tokens;
184 },
185 StreamingEvent::MessageDelta { delta, usage } => {
186 if delta.stop_reason.is_some() {
187
188 yield Ok(RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
189 usage: PartialUsage {
190 output_tokens: usage.output_tokens,
191 input_tokens: Some(input_tokens.try_into().expect("Failed to convert input_tokens to usize")),
192 }
193 }))
194 }
195 }
196 _ => {}
197 }
198
199 if let Some(result) = handle_event(&event, &mut current_tool_call) {
200 yield result;
201 }
202 },
203 Err(e) => {
204 if !sse.data.trim().is_empty() {
205 yield Err(CompletionError::ResponseError(
206 format!("Failed to parse JSON: {} (Data: {})", e, sse.data)
207 ));
208 }
209 }
210 }
211 },
212 Err(e) => {
213 yield Err(CompletionError::ResponseError(format!("SSE Error: {e}")));
214 break;
215 }
216 }
217 }
218 });
219
220 Ok(streaming::StreamingCompletionResponse::stream(stream))
221 }
222}
223
224fn handle_event(
225 event: &StreamingEvent,
226 current_tool_call: &mut Option<ToolCallState>,
227) -> Option<Result<RawStreamingChoice<StreamingCompletionResponse>, CompletionError>> {
228 match event {
229 StreamingEvent::ContentBlockDelta { delta, .. } => match delta {
230 ContentDelta::TextDelta { text } => {
231 if current_tool_call.is_none() {
232 return Some(Ok(RawStreamingChoice::Message(text.clone())));
233 }
234 None
235 }
236 ContentDelta::InputJsonDelta { partial_json } => {
237 if let Some(tool_call) = current_tool_call {
238 tool_call.input_json.push_str(partial_json);
239 }
240 None
241 }
242 },
243 StreamingEvent::ContentBlockStart { content_block, .. } => match content_block {
244 Content::ToolUse { id, name, .. } => {
245 *current_tool_call = Some(ToolCallState {
246 name: name.clone(),
247 id: id.clone(),
248 input_json: String::new(),
249 });
250 None
251 }
252 _ => None,
254 },
255 StreamingEvent::ContentBlockStop { .. } => {
256 if let Some(tool_call) = current_tool_call.take() {
257 let json_str = if tool_call.input_json.is_empty() {
258 "{}"
259 } else {
260 &tool_call.input_json
261 };
262 match serde_json::from_str(json_str) {
263 Ok(json_value) => Some(Ok(RawStreamingChoice::ToolCall {
264 name: tool_call.name,
265 id: tool_call.id,
266 arguments: json_value,
267 call_id: None,
268 })),
269 Err(e) => Some(Err(CompletionError::from(e))),
270 }
271 } else {
272 None
273 }
274 }
275 StreamingEvent::MessageStart { .. }
277 | StreamingEvent::MessageDelta { .. }
278 | StreamingEvent::MessageStop
279 | StreamingEvent::Ping
280 | StreamingEvent::Unknown => None,
281 }
282}