rig/providers/anthropic/
streaming.rs

1use async_stream::stream;
2use futures::StreamExt;
3use serde::Deserialize;
4use serde_json::json;
5
6use super::completion::{CompletionModel, Content, Message, ToolChoice, ToolDefinition, Usage};
7use crate::completion::{CompletionError, CompletionRequest};
8use crate::json_utils::merge_inplace;
9use crate::message::MessageError;
10use crate::streaming::{StreamingChoice, StreamingCompletionModel, StreamingResult};
11
12#[derive(Debug, Deserialize)]
13#[serde(tag = "type", rename_all = "snake_case")]
14pub enum StreamingEvent {
15    MessageStart {
16        message: MessageStart,
17    },
18    ContentBlockStart {
19        index: usize,
20        content_block: Content,
21    },
22    ContentBlockDelta {
23        index: usize,
24        delta: ContentDelta,
25    },
26    ContentBlockStop {
27        index: usize,
28    },
29    MessageDelta {
30        delta: MessageDelta,
31        usage: Usage,
32    },
33    MessageStop,
34    Ping,
35}
36
37#[derive(Debug, Deserialize)]
38pub struct MessageStart {
39    pub id: String,
40    pub role: String,
41    pub content: Vec<Content>,
42    pub model: String,
43    pub stop_reason: Option<String>,
44    pub stop_sequence: Option<String>,
45    pub usage: Usage,
46}
47
48#[derive(Debug, Deserialize)]
49#[serde(tag = "type", rename_all = "snake_case")]
50pub enum ContentDelta {
51    TextDelta { text: String },
52    InputJsonDelta { partial_json: String },
53}
54
55#[derive(Debug, Deserialize)]
56pub struct MessageDelta {
57    pub stop_reason: Option<String>,
58    pub stop_sequence: Option<String>,
59}
60
/// Accumulator for a tool call that is still being streamed.
///
/// It is created when a tool-use content block starts, grows as
/// `input_json_delta` fragments arrive, and is consumed when the block stops.
/// `Debug` is derived so in-flight state can be logged while diagnosing
/// streaming problems.
#[derive(Debug, Default)]
struct ToolCallState {
    /// Name of the tool being invoked.
    name: String,
    /// Identifier the provider assigned to this tool call.
    id: String,
    /// Concatenation of the JSON input fragments received so far; it is only
    /// parsed as JSON once the content block is complete.
    input_json: String,
}
67
68impl StreamingCompletionModel for CompletionModel {
69    async fn stream(
70        &self,
71        completion_request: CompletionRequest,
72    ) -> Result<StreamingResult, CompletionError> {
73        let max_tokens = if let Some(tokens) = completion_request.max_tokens {
74            tokens
75        } else if let Some(tokens) = self.default_max_tokens {
76            tokens
77        } else {
78            return Err(CompletionError::RequestError(
79                "`max_tokens` must be set for Anthropic".into(),
80            ));
81        };
82
83        let prompt_message: Message = completion_request
84            .prompt_with_context()
85            .try_into()
86            .map_err(|e: MessageError| CompletionError::RequestError(e.into()))?;
87
88        let mut messages = completion_request
89            .chat_history
90            .into_iter()
91            .map(|message| {
92                message
93                    .try_into()
94                    .map_err(|e: MessageError| CompletionError::RequestError(e.into()))
95            })
96            .collect::<Result<Vec<Message>, _>>()?;
97
98        messages.push(prompt_message);
99
100        let mut request = json!({
101            "model": self.model,
102            "messages": messages,
103            "max_tokens": max_tokens,
104            "system": completion_request.preamble.unwrap_or("".to_string()),
105            "stream": true,
106        });
107
108        if let Some(temperature) = completion_request.temperature {
109            merge_inplace(&mut request, json!({ "temperature": temperature }));
110        }
111
112        if !completion_request.tools.is_empty() {
113            merge_inplace(
114                &mut request,
115                json!({
116                    "tools": completion_request
117                        .tools
118                        .into_iter()
119                        .map(|tool| ToolDefinition {
120                            name: tool.name,
121                            description: Some(tool.description),
122                            input_schema: tool.parameters,
123                        })
124                        .collect::<Vec<_>>(),
125                    "tool_choice": ToolChoice::Auto,
126                }),
127            );
128        }
129
130        if let Some(ref params) = completion_request.additional_params {
131            merge_inplace(&mut request, params.clone())
132        }
133
134        let response = self
135            .client
136            .post("/v1/messages")
137            .json(&request)
138            .send()
139            .await?;
140
141        if !response.status().is_success() {
142            return Err(CompletionError::ProviderError(response.text().await?));
143        }
144
145        Ok(Box::pin(stream! {
146            let mut current_tool_call: Option<ToolCallState> = None;
147            let mut stream = response.bytes_stream();
148
149            while let Some(chunk_result) = stream.next().await {
150                let chunk = match chunk_result {
151                    Ok(c) => c,
152                    Err(e) => {
153                        yield Err(CompletionError::from(e));
154                        break;
155                    }
156                };
157
158                let text = match String::from_utf8(chunk.to_vec()) {
159                    Ok(t) => t,
160                    Err(e) => {
161                        yield Err(CompletionError::ResponseError(e.to_string()));
162                        break;
163                    }
164                };
165
166                for line in text.lines() {
167                    if let Some(data) = line.strip_prefix("data: ") {
168                        if let Ok(event) = serde_json::from_str::<StreamingEvent>(data) {
169                            match event {
170                                StreamingEvent::ContentBlockDelta { delta, .. } => {
171                                    match delta {
172                                        ContentDelta::TextDelta { text } => {
173                                            if current_tool_call.is_none() {
174                                                yield Ok(StreamingChoice::Message(text));
175                                            }
176                                        }
177                                        ContentDelta::InputJsonDelta { partial_json } => {
178                                            if let Some(ref mut tool_call) = current_tool_call {
179                                                tool_call.input_json.push_str(&partial_json);
180                                            }
181                                        }
182                                    }
183                                }
184                                StreamingEvent::ContentBlockStart {
185                                    content_block: Content::ToolUse { id, name, .. },
186                                    ..
187                                } => {
188                                    current_tool_call = Some(ToolCallState {
189                                        name,
190                                        id,
191                                        input_json: String::new(),
192                                    });
193                                }
194                                StreamingEvent::ContentBlockStop { .. } => {
195                                    if let Some(tool_call) = current_tool_call.take() {
196                                        let json_str = if tool_call.input_json.is_empty() {
197                                            "{}"
198                                        } else {
199                                            &tool_call.input_json
200                                        };
201                                        match serde_json::from_str(json_str) {
202                                            Ok(json_value) => {
203                                                yield Ok(StreamingChoice::ToolCall(
204                                                    tool_call.name,
205                                                    tool_call.id,
206                                                    json_value,
207                                                ));
208                                            }
209                                            Err(e) => {
210                                                yield Err(CompletionError::from(e));
211                                            }
212                                        }
213                                    }
214                                },
215                                _ => {}
216                            }
217                        }
218                    }
219                }
220            }
221        }))
222    }
223}