rig/providers/anthropic/
streaming.rs1use async_stream::stream;
2use futures::StreamExt;
3use serde::Deserialize;
4use serde_json::json;
5
6use super::completion::{CompletionModel, Content, Message, ToolChoice, ToolDefinition, Usage};
7use crate::completion::{CompletionError, CompletionRequest};
8use crate::json_utils::merge_inplace;
9use crate::message::MessageError;
10use crate::streaming::{StreamingChoice, StreamingCompletionModel, StreamingResult};
11
/// Server-sent event payloads emitted by the Anthropic Messages streaming API.
///
/// Each SSE `data:` line carries one JSON object whose `"type"` field selects
/// the variant (snake_case, e.g. `"content_block_delta"` -> `ContentBlockDelta`).
#[derive(Debug, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum StreamingEvent {
    /// First event of a stream; carries the initial message envelope.
    MessageStart {
        message: MessageStart,
    },
    /// Opens a new content block (text or tool use) at position `index`.
    ContentBlockStart {
        index: usize,
        content_block: Content,
    },
    /// Incremental update to the content block at position `index`.
    ContentBlockDelta {
        index: usize,
        delta: ContentDelta,
    },
    /// Closes the content block at position `index`.
    ContentBlockStop {
        index: usize,
    },
    /// Top-level message changes (stop reason/sequence) plus token usage.
    MessageDelta {
        delta: MessageDelta,
        usage: Usage,
    },
    /// Final event of a successful stream.
    MessageStop,
    /// Keep-alive event; carries no data.
    Ping,
}
36
/// Payload of a `message_start` event: the envelope of the message being
/// streamed, delivered before any content blocks arrive.
#[derive(Debug, Deserialize)]
pub struct MessageStart {
    pub id: String,
    pub role: String,
    // Presumably empty at stream start, with content arriving via later
    // content_block_* events — not verifiable from this file alone.
    pub content: Vec<Content>,
    pub model: String,
    pub stop_reason: Option<String>,
    pub stop_sequence: Option<String>,
    pub usage: Usage,
}
47
/// Incremental content carried by a `content_block_delta` event, tagged by
/// the JSON `"type"` field.
#[derive(Debug, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ContentDelta {
    /// A fragment of assistant text (`text_delta`).
    TextDelta { text: String },
    /// A fragment of a tool call's JSON input (`input_json_delta`).
    InputJsonDelta { partial_json: String },
}
54
/// Top-level message changes from a `message_delta` event.
#[derive(Debug, Deserialize)]
pub struct MessageDelta {
    pub stop_reason: Option<String>,
    pub stop_sequence: Option<String>,
}
60
/// Accumulator for a tool-use content block while it streams in.
///
/// `name` and `id` come from the opening `content_block_start` event;
/// `input_json` concatenates the `partial_json` fragments of subsequent
/// `input_json_delta` events until the block is closed, at which point the
/// collected string is parsed as the tool's input.
#[derive(Default)]
struct ToolCallState {
    name: String,
    id: String,
    input_json: String,
}
67
68impl StreamingCompletionModel for CompletionModel {
69 async fn stream(
70 &self,
71 completion_request: CompletionRequest,
72 ) -> Result<StreamingResult, CompletionError> {
73 let max_tokens = if let Some(tokens) = completion_request.max_tokens {
74 tokens
75 } else if let Some(tokens) = self.default_max_tokens {
76 tokens
77 } else {
78 return Err(CompletionError::RequestError(
79 "`max_tokens` must be set for Anthropic".into(),
80 ));
81 };
82
83 let prompt_message: Message = completion_request
84 .prompt_with_context()
85 .try_into()
86 .map_err(|e: MessageError| CompletionError::RequestError(e.into()))?;
87
88 let mut messages = completion_request
89 .chat_history
90 .into_iter()
91 .map(|message| {
92 message
93 .try_into()
94 .map_err(|e: MessageError| CompletionError::RequestError(e.into()))
95 })
96 .collect::<Result<Vec<Message>, _>>()?;
97
98 messages.push(prompt_message);
99
100 let mut request = json!({
101 "model": self.model,
102 "messages": messages,
103 "max_tokens": max_tokens,
104 "system": completion_request.preamble.unwrap_or("".to_string()),
105 "stream": true,
106 });
107
108 if let Some(temperature) = completion_request.temperature {
109 merge_inplace(&mut request, json!({ "temperature": temperature }));
110 }
111
112 if !completion_request.tools.is_empty() {
113 merge_inplace(
114 &mut request,
115 json!({
116 "tools": completion_request
117 .tools
118 .into_iter()
119 .map(|tool| ToolDefinition {
120 name: tool.name,
121 description: Some(tool.description),
122 input_schema: tool.parameters,
123 })
124 .collect::<Vec<_>>(),
125 "tool_choice": ToolChoice::Auto,
126 }),
127 );
128 }
129
130 if let Some(ref params) = completion_request.additional_params {
131 merge_inplace(&mut request, params.clone())
132 }
133
134 let response = self
135 .client
136 .post("/v1/messages")
137 .json(&request)
138 .send()
139 .await?;
140
141 if !response.status().is_success() {
142 return Err(CompletionError::ProviderError(response.text().await?));
143 }
144
145 Ok(Box::pin(stream! {
146 let mut current_tool_call: Option<ToolCallState> = None;
147 let mut stream = response.bytes_stream();
148
149 while let Some(chunk_result) = stream.next().await {
150 let chunk = match chunk_result {
151 Ok(c) => c,
152 Err(e) => {
153 yield Err(CompletionError::from(e));
154 break;
155 }
156 };
157
158 let text = match String::from_utf8(chunk.to_vec()) {
159 Ok(t) => t,
160 Err(e) => {
161 yield Err(CompletionError::ResponseError(e.to_string()));
162 break;
163 }
164 };
165
166 for line in text.lines() {
167 if let Some(data) = line.strip_prefix("data: ") {
168 if let Ok(event) = serde_json::from_str::<StreamingEvent>(data) {
169 match event {
170 StreamingEvent::ContentBlockDelta { delta, .. } => {
171 match delta {
172 ContentDelta::TextDelta { text } => {
173 if current_tool_call.is_none() {
174 yield Ok(StreamingChoice::Message(text));
175 }
176 }
177 ContentDelta::InputJsonDelta { partial_json } => {
178 if let Some(ref mut tool_call) = current_tool_call {
179 tool_call.input_json.push_str(&partial_json);
180 }
181 }
182 }
183 }
184 StreamingEvent::ContentBlockStart {
185 content_block: Content::ToolUse { id, name, .. },
186 ..
187 } => {
188 current_tool_call = Some(ToolCallState {
189 name,
190 id,
191 input_json: String::new(),
192 });
193 }
194 StreamingEvent::ContentBlockStop { .. } => {
195 if let Some(tool_call) = current_tool_call.take() {
196 let json_str = if tool_call.input_json.is_empty() {
197 "{}"
198 } else {
199 &tool_call.input_json
200 };
201 match serde_json::from_str(json_str) {
202 Ok(json_value) => {
203 yield Ok(StreamingChoice::ToolCall(
204 tool_call.name,
205 tool_call.id,
206 json_value,
207 ));
208 }
209 Err(e) => {
210 yield Err(CompletionError::from(e));
211 }
212 }
213 }
214 },
215 _ => {}
216 }
217 }
218 }
219 }
220 }
221 }))
222 }
223}