1use async_stream::stream;
2use futures::StreamExt;
3use serde::Deserialize;
4use serde_json::json;
5
6use super::completion::{CompletionModel, Content, Message, ToolChoice, ToolDefinition, Usage};
7use super::decoders::sse::from_response as sse_from_response;
8use crate::completion::{CompletionError, CompletionRequest};
9use crate::json_utils::merge_inplace;
10use crate::streaming::{StreamingChoice, StreamingCompletionModel, StreamingResult};
11
/// Server-sent event payloads emitted by the Anthropic Messages streaming API.
///
/// Each SSE `data:` line is deserialized into one of these variants based on
/// its `"type"` tag (snake_cased, e.g. `content_block_delta`).
#[derive(Debug, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum StreamingEvent {
    /// Start of a message; carries the initial message envelope.
    MessageStart {
        message: MessageStart,
    },
    /// A new content block (text or tool-use) begins at `index`.
    ContentBlockStart {
        index: usize,
        content_block: Content,
    },
    /// Incremental update to the content block at `index`.
    ContentBlockDelta {
        index: usize,
        delta: ContentDelta,
    },
    /// The content block at `index` is complete.
    ContentBlockStop {
        index: usize,
    },
    /// Top-level message metadata update (stop reason) plus usage so far.
    MessageDelta {
        delta: MessageDelta,
        usage: PartialUsage,
    },
    /// The message is complete.
    MessageStop,
    /// Keep-alive event; carries no payload.
    Ping,
    /// Fallback for event types this client does not recognize, so new
    /// server-side event kinds do not break deserialization.
    #[serde(other)]
    Unknown,
}
38
/// Payload of a `message_start` event: the envelope of the message being
/// streamed, including the responding model and initial token usage.
#[derive(Debug, Deserialize)]
pub struct MessageStart {
    pub id: String,
    pub role: String,
    // Typically empty at stream start; content arrives via content-block events.
    pub content: Vec<Content>,
    pub model: String,
    pub stop_reason: Option<String>,
    pub stop_sequence: Option<String>,
    pub usage: Usage,
}
49
/// Incremental payload inside a `content_block_delta` event, tagged by
/// `"type"` (`text_delta` or `input_json_delta`).
#[derive(Debug, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ContentDelta {
    /// A fragment of assistant text output.
    TextDelta { text: String },
    /// A fragment of a tool call's JSON arguments; fragments must be
    /// concatenated before parsing.
    InputJsonDelta { partial_json: String },
}
56
/// Message-level fields updated by a `message_delta` event.
#[derive(Debug, Deserialize)]
pub struct MessageDelta {
    pub stop_reason: Option<String>,
    pub stop_sequence: Option<String>,
}
62
/// Token-usage counts attached to a `message_delta` event.
#[derive(Debug, Deserialize)]
pub struct PartialUsage {
    pub output_tokens: usize,
    // Not always present mid-stream, hence optional with a default of `None`.
    #[serde(default)]
    pub input_tokens: Option<usize>,
}
69
/// Accumulator for a tool call being assembled across multiple streaming
/// events: created on `content_block_start` (tool-use), grown by each
/// `input_json_delta`, and finalized on `content_block_stop`.
#[derive(Default)]
struct ToolCallState {
    name: String,
    id: String,
    // Concatenated `partial_json` fragments; parsed once the block closes.
    input_json: String,
}
76
77impl StreamingCompletionModel for CompletionModel {
78 async fn stream(
79 &self,
80 completion_request: CompletionRequest,
81 ) -> Result<StreamingResult, CompletionError> {
82 let max_tokens = if let Some(tokens) = completion_request.max_tokens {
83 tokens
84 } else if let Some(tokens) = self.default_max_tokens {
85 tokens
86 } else {
87 return Err(CompletionError::RequestError(
88 "`max_tokens` must be set for Anthropic".into(),
89 ));
90 };
91
92 let mut full_history = vec![];
93 if let Some(docs) = completion_request.normalized_documents() {
94 full_history.push(docs);
95 }
96 full_history.extend(completion_request.chat_history);
97
98 let full_history = full_history
99 .into_iter()
100 .map(Message::try_from)
101 .collect::<Result<Vec<Message>, _>>()?;
102
103 let mut request = json!({
104 "model": self.model,
105 "messages": full_history,
106 "max_tokens": max_tokens,
107 "system": completion_request.preamble.unwrap_or("".to_string()),
108 "stream": true,
109 });
110
111 if let Some(temperature) = completion_request.temperature {
112 merge_inplace(&mut request, json!({ "temperature": temperature }));
113 }
114
115 if !completion_request.tools.is_empty() {
116 merge_inplace(
117 &mut request,
118 json!({
119 "tools": completion_request
120 .tools
121 .into_iter()
122 .map(|tool| ToolDefinition {
123 name: tool.name,
124 description: Some(tool.description),
125 input_schema: tool.parameters,
126 })
127 .collect::<Vec<_>>(),
128 "tool_choice": ToolChoice::Auto,
129 }),
130 );
131 }
132
133 if let Some(ref params) = completion_request.additional_params {
134 merge_inplace(&mut request, params.clone())
135 }
136
137 let response = self
138 .client
139 .post("/v1/messages")
140 .json(&request)
141 .send()
142 .await?;
143
144 if !response.status().is_success() {
145 return Err(CompletionError::ProviderError(response.text().await?));
146 }
147
148 let sse_stream = sse_from_response(response);
150
151 Ok(Box::pin(stream! {
152 let mut current_tool_call: Option<ToolCallState> = None;
153 let mut sse_stream = Box::pin(sse_stream);
154
155 while let Some(sse_result) = sse_stream.next().await {
156 match sse_result {
157 Ok(sse) => {
158 match serde_json::from_str::<StreamingEvent>(&sse.data) {
160 Ok(event) => {
161 if let Some(result) = handle_event(&event, &mut current_tool_call) {
162 yield result;
163 }
164 },
165 Err(e) => {
166 if !sse.data.trim().is_empty() {
167 yield Err(CompletionError::ResponseError(
168 format!("Failed to parse JSON: {} (Data: {})", e, sse.data)
169 ));
170 }
171 }
172 }
173 },
174 Err(e) => {
175 yield Err(CompletionError::ResponseError(format!("SSE Error: {}", e)));
176 break;
177 }
178 }
179 }
180 }))
181 }
182}
183
184fn handle_event(
185 event: &StreamingEvent,
186 current_tool_call: &mut Option<ToolCallState>,
187) -> Option<Result<StreamingChoice, CompletionError>> {
188 match event {
189 StreamingEvent::ContentBlockDelta { delta, .. } => match delta {
190 ContentDelta::TextDelta { text } => {
191 if current_tool_call.is_none() {
192 return Some(Ok(StreamingChoice::Message(text.clone())));
193 }
194 None
195 }
196 ContentDelta::InputJsonDelta { partial_json } => {
197 if let Some(ref mut tool_call) = current_tool_call {
198 tool_call.input_json.push_str(partial_json);
199 }
200 None
201 }
202 },
203 StreamingEvent::ContentBlockStart { content_block, .. } => match content_block {
204 Content::ToolUse { id, name, .. } => {
205 *current_tool_call = Some(ToolCallState {
206 name: name.clone(),
207 id: id.clone(),
208 input_json: String::new(),
209 });
210 None
211 }
212 _ => None,
214 },
215 StreamingEvent::ContentBlockStop { .. } => {
216 if let Some(tool_call) = current_tool_call.take() {
217 let json_str = if tool_call.input_json.is_empty() {
218 "{}"
219 } else {
220 &tool_call.input_json
221 };
222 match serde_json::from_str(json_str) {
223 Ok(json_value) => Some(Ok(StreamingChoice::ToolCall(
224 tool_call.name,
225 tool_call.id,
226 json_value,
227 ))),
228 Err(e) => Some(Err(CompletionError::from(e))),
229 }
230 } else {
231 None
232 }
233 }
234 StreamingEvent::MessageStart { .. }
236 | StreamingEvent::MessageDelta { .. }
237 | StreamingEvent::MessageStop
238 | StreamingEvent::Ping
239 | StreamingEvent::Unknown => None,
240 }
241}