1use async_stream::stream;
2use futures::StreamExt;
3use serde::Deserialize;
4use serde_json::json;
5
6use super::completion::{CompletionModel, Content, Message, ToolChoice, ToolDefinition, Usage};
7use super::decoders::sse::from_response as sse_from_response;
8use crate::completion::{CompletionError, CompletionRequest};
9use crate::json_utils::merge_inplace;
10use crate::streaming;
11use crate::streaming::{RawStreamingChoice, StreamingResult};
12
13#[derive(Debug, Deserialize)]
14#[serde(tag = "type", rename_all = "snake_case")]
15pub enum StreamingEvent {
16 MessageStart {
17 message: MessageStart,
18 },
19 ContentBlockStart {
20 index: usize,
21 content_block: Content,
22 },
23 ContentBlockDelta {
24 index: usize,
25 delta: ContentDelta,
26 },
27 ContentBlockStop {
28 index: usize,
29 },
30 MessageDelta {
31 delta: MessageDelta,
32 usage: PartialUsage,
33 },
34 MessageStop,
35 Ping,
36 #[serde(other)]
37 Unknown,
38}
39
40#[derive(Debug, Deserialize)]
41pub struct MessageStart {
42 pub id: String,
43 pub role: String,
44 pub content: Vec<Content>,
45 pub model: String,
46 pub stop_reason: Option<String>,
47 pub stop_sequence: Option<String>,
48 pub usage: Usage,
49}
50
51#[derive(Debug, Deserialize)]
52#[serde(tag = "type", rename_all = "snake_case")]
53pub enum ContentDelta {
54 TextDelta { text: String },
55 InputJsonDelta { partial_json: String },
56}
57
58#[derive(Debug, Deserialize)]
59pub struct MessageDelta {
60 pub stop_reason: Option<String>,
61 pub stop_sequence: Option<String>,
62}
63
64#[derive(Debug, Deserialize, Clone)]
65pub struct PartialUsage {
66 pub output_tokens: usize,
67 #[serde(default)]
68 pub input_tokens: Option<usize>,
69}
70
/// In-progress tool-use content block whose input JSON is still streaming in.
///
/// Created on a `tool_use` content-block start, extended by `input_json_delta` fragments,
/// and consumed when the block stops.
#[derive(Default)]
struct ToolCallState {
    /// Tool name taken from the `tool_use` block.
    name: String,
    /// Tool-call identifier taken from the `tool_use` block.
    id: String,
    /// Concatenation of every `partial_json` fragment received so far.
    input_json: String,
}
77
78#[derive(Clone)]
79pub struct StreamingCompletionResponse {
80 pub usage: PartialUsage,
81}
82
83impl CompletionModel {
84 pub(crate) async fn stream(
85 &self,
86 completion_request: CompletionRequest,
87 ) -> Result<streaming::StreamingCompletionResponse<StreamingCompletionResponse>, CompletionError>
88 {
89 let max_tokens = if let Some(tokens) = completion_request.max_tokens {
90 tokens
91 } else if let Some(tokens) = self.default_max_tokens {
92 tokens
93 } else {
94 return Err(CompletionError::RequestError(
95 "`max_tokens` must be set for Anthropic".into(),
96 ));
97 };
98
99 let mut full_history = vec![];
100 if let Some(docs) = completion_request.normalized_documents() {
101 full_history.push(docs);
102 }
103 full_history.extend(completion_request.chat_history);
104
105 let full_history = full_history
106 .into_iter()
107 .map(Message::try_from)
108 .collect::<Result<Vec<Message>, _>>()?;
109
110 let mut request = json!({
111 "model": self.model,
112 "messages": full_history,
113 "max_tokens": max_tokens,
114 "system": completion_request.preamble.unwrap_or("".to_string()),
115 "stream": true,
116 });
117
118 if let Some(temperature) = completion_request.temperature {
119 merge_inplace(&mut request, json!({ "temperature": temperature }));
120 }
121
122 if !completion_request.tools.is_empty() {
123 merge_inplace(
124 &mut request,
125 json!({
126 "tools": completion_request
127 .tools
128 .into_iter()
129 .map(|tool| ToolDefinition {
130 name: tool.name,
131 description: Some(tool.description),
132 input_schema: tool.parameters,
133 })
134 .collect::<Vec<_>>(),
135 "tool_choice": ToolChoice::Auto,
136 }),
137 );
138 }
139
140 if let Some(ref params) = completion_request.additional_params {
141 merge_inplace(&mut request, params.clone())
142 }
143
144 let response = self
145 .client
146 .post("/v1/messages")
147 .json(&request)
148 .send()
149 .await?;
150
151 if !response.status().is_success() {
152 return Err(CompletionError::ProviderError(response.text().await?));
153 }
154
155 let sse_stream = sse_from_response(response);
157
158 let stream: StreamingResult<StreamingCompletionResponse> = Box::pin(stream! {
159 let mut current_tool_call: Option<ToolCallState> = None;
160 let mut sse_stream = Box::pin(sse_stream);
161 let mut input_tokens = 0;
162
163 while let Some(sse_result) = sse_stream.next().await {
164 match sse_result {
165 Ok(sse) => {
166 match serde_json::from_str::<StreamingEvent>(&sse.data) {
168 Ok(event) => {
169 match &event {
170 StreamingEvent::MessageStart { message } => {
171 input_tokens = message.usage.input_tokens;
172 },
173 StreamingEvent::MessageDelta { delta, usage } => {
174 if delta.stop_reason.is_some() {
175
176 yield Ok(RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
177 usage: PartialUsage {
178 output_tokens: usage.output_tokens,
179 input_tokens: Some(input_tokens.try_into().expect("Failed to convert input_tokens to usize")),
180 }
181 }))
182 }
183 }
184 _ => {}
185 }
186
187 if let Some(result) = handle_event(&event, &mut current_tool_call) {
188 yield result;
189 }
190 },
191 Err(e) => {
192 if !sse.data.trim().is_empty() {
193 yield Err(CompletionError::ResponseError(
194 format!("Failed to parse JSON: {} (Data: {})", e, sse.data)
195 ));
196 }
197 }
198 }
199 },
200 Err(e) => {
201 yield Err(CompletionError::ResponseError(format!("SSE Error: {e}")));
202 break;
203 }
204 }
205 }
206 });
207
208 Ok(streaming::StreamingCompletionResponse::stream(stream))
209 }
210}
211
212fn handle_event(
213 event: &StreamingEvent,
214 current_tool_call: &mut Option<ToolCallState>,
215) -> Option<Result<RawStreamingChoice<StreamingCompletionResponse>, CompletionError>> {
216 match event {
217 StreamingEvent::ContentBlockDelta { delta, .. } => match delta {
218 ContentDelta::TextDelta { text } => {
219 if current_tool_call.is_none() {
220 return Some(Ok(RawStreamingChoice::Message(text.clone())));
221 }
222 None
223 }
224 ContentDelta::InputJsonDelta { partial_json } => {
225 if let Some(ref mut tool_call) = current_tool_call {
226 tool_call.input_json.push_str(partial_json);
227 }
228 None
229 }
230 },
231 StreamingEvent::ContentBlockStart { content_block, .. } => match content_block {
232 Content::ToolUse { id, name, .. } => {
233 *current_tool_call = Some(ToolCallState {
234 name: name.clone(),
235 id: id.clone(),
236 input_json: String::new(),
237 });
238 None
239 }
240 _ => None,
242 },
243 StreamingEvent::ContentBlockStop { .. } => {
244 if let Some(tool_call) = current_tool_call.take() {
245 let json_str = if tool_call.input_json.is_empty() {
246 "{}"
247 } else {
248 &tool_call.input_json
249 };
250 match serde_json::from_str(json_str) {
251 Ok(json_value) => Some(Ok(RawStreamingChoice::ToolCall {
252 name: tool_call.name,
253 id: tool_call.id,
254 arguments: json_value,
255 })),
256 Err(e) => Some(Err(CompletionError::from(e))),
257 }
258 } else {
259 None
260 }
261 }
262 StreamingEvent::MessageStart { .. }
264 | StreamingEvent::MessageDelta { .. }
265 | StreamingEvent::MessageStop
266 | StreamingEvent::Ping
267 | StreamingEvent::Unknown => None,
268 }
269}