syncable_cli/bedrock/streaming.rs

1use super::types::completion_request::AwsCompletionRequest;
2use super::{completion::CompletionModel, types::errors::AwsSdkConverseStreamError};
3use async_stream::stream;
4use aws_sdk_bedrockruntime::types as aws_bedrock;
5use rig::completion::GetTokenUsage;
6use rig::streaming::StreamingCompletionResponse;
7use rig::{
8    completion::CompletionError,
9    streaming::{RawStreamingChoice, RawStreamingToolCall},
10};
11use serde::{Deserialize, Serialize};
12
13#[derive(Clone, Deserialize, Serialize)]
14pub struct BedrockStreamingResponse {
15    pub usage: Option<BedrockUsage>,
16}
17
18#[derive(Clone, Deserialize, Serialize)]
19pub struct BedrockUsage {
20    pub input_tokens: i32,
21    pub output_tokens: i32,
22    pub total_tokens: i32,
23}
24
25impl GetTokenUsage for BedrockStreamingResponse {
26    fn token_usage(&self) -> Option<rig::completion::Usage> {
27        self.usage.as_ref().map(|u| rig::completion::Usage {
28            input_tokens: u.input_tokens as u64,
29            output_tokens: u.output_tokens as u64,
30            total_tokens: u.total_tokens as u64,
31        })
32    }
33}
34
/// Accumulator for the tool-use content block that is currently streaming.
///
/// It is created when a tool-use block starts. The raw JSON argument fragments
/// carried by each tool-use delta are appended to `input_json` in arrival order.
#[derive(Default)]
struct ToolCallState {
    /// Identifier Bedrock assigned to this tool invocation.
    id: String,
    /// Name of the tool the model wants to call.
    name: String,
    /// Argument JSON concatenated so far (may be incomplete mid-stream).
    input_json: String,
}
41
/// Accumulator for the reasoning content block that is currently streaming.
///
/// Reasoning text deltas are appended to `content`. The most recent signature
/// delta, if any arrives, is kept in `signature`.
#[derive(Default)]
struct ReasoningState {
    /// Signature attached to the reasoning block, if one was received.
    signature: Option<String>,
    /// Reasoning text concatenated from the deltas seen so far.
    content: String,
}
47
48impl CompletionModel {
49    pub(crate) async fn stream(
50        &self,
51        completion_request: rig::completion::CompletionRequest,
52    ) -> Result<StreamingCompletionResponse<BedrockStreamingResponse>, CompletionError> {
53        let request = AwsCompletionRequest(completion_request);
54
55        let mut converse_builder = self
56            .client
57            .get_inner()
58            .await
59            .converse_stream()
60            .model_id(self.model.as_str());
61
62        let tool_config = request.tools_config()?;
63        let prompt_with_history = request.messages()?;
64        converse_builder = converse_builder
65            .set_additional_model_request_fields(request.additional_params())
66            .set_inference_config(request.inference_config())
67            .set_tool_config(tool_config)
68            .set_system(request.system_prompt())
69            .set_messages(Some(prompt_with_history));
70
71        let response = converse_builder.send().await.map_err(|sdk_error| {
72            Into::<CompletionError>::into(AwsSdkConverseStreamError(sdk_error))
73        })?;
74
75        let stream = Box::pin(stream! {
76            let mut current_tool_call: Option<ToolCallState> = None;
77            let mut current_reasoning: Option<ReasoningState> = None;
78            let mut stream = response.stream;
79            while let Ok(Some(output)) = stream.recv().await {
80                match output {
81                    aws_bedrock::ConverseStreamOutput::ContentBlockDelta(event) => {
82                        let delta = event.delta.ok_or(CompletionError::ProviderError("The delta for a content block is missing".into()))?;
83                        match delta {
84                            aws_bedrock::ContentBlockDelta::Text(text) => {
85                                if current_tool_call.is_none() {
86                                    yield Ok(RawStreamingChoice::Message(text))
87                                }
88                            },
89                            aws_bedrock::ContentBlockDelta::ToolUse(tool) => {
90                                if let Some(ref mut tool_call) = current_tool_call {
91                                    let delta = tool.input().to_string();
92                                    tool_call.input_json.push_str(&delta);
93
94                                    // Emit the delta so UI can show progress
95                                    yield Ok(RawStreamingChoice::ToolCallDelta {
96                                        id: tool_call.id.clone(),
97                                        content: rig::streaming::ToolCallDeltaContent::Delta(delta),
98                                    });
99                                }
100                            },
101                            aws_bedrock::ContentBlockDelta::ReasoningContent(reasoning) => {
102                                match reasoning {
103                                    aws_bedrock::ReasoningContentBlockDelta::Text(text) => {
104                                        if current_reasoning.is_none() {
105                                            current_reasoning = Some(ReasoningState::default());
106                                        }
107
108                                        if let Some(ref mut state) = current_reasoning {
109                                            state.content.push_str(text.as_str());
110                                        }
111
112                                        if !text.is_empty() {
113                                            yield Ok(RawStreamingChoice::ReasoningDelta {
114                                                reasoning: text.clone(),
115                                                id: None,
116                                            })
117                                        }
118                                    },
119                                    aws_bedrock::ReasoningContentBlockDelta::Signature(signature) => {
120                                        if current_reasoning.is_none() {
121                                            current_reasoning = Some(ReasoningState::default());
122                                        }
123
124                                        if let Some(ref mut state) = current_reasoning {
125                                            state.signature = Some(signature.clone());
126                                        }
127                                    },
128                                    _ => {}
129                                }
130                            },
131                            _ => {}
132                        }
133                    },
134                    aws_bedrock::ConverseStreamOutput::ContentBlockStart(event) => {
135                        match event.start.ok_or(CompletionError::ProviderError("ContentBlockStart has no data".into()))? {
136                            aws_bedrock::ContentBlockStart::ToolUse(tool_use) => {
137                                current_tool_call = Some(ToolCallState {
138                                    name: tool_use.name,
139                                    id: tool_use.tool_use_id,
140                                    input_json: String::new(),
141                                });
142                            },
143                            _ => yield Err(CompletionError::ProviderError("Stream is empty".into()))
144                        }
145                    },
146                    aws_bedrock::ConverseStreamOutput::ContentBlockStop(_event) => {
147                        if let Some(reasoning_state) = current_reasoning.take()
148                            && !reasoning_state.content.is_empty() {
149                                yield Ok(RawStreamingChoice::Reasoning {
150                                    reasoning: reasoning_state.content,
151                                    id: None,
152                                    signature: reasoning_state.signature,
153                                })
154                            }
155                    },
156                    aws_bedrock::ConverseStreamOutput::MessageStop(message_stop_event) => {
157                        match message_stop_event.stop_reason {
158                            aws_bedrock::StopReason::ToolUse => {
159                                if let Some(tool_call) = current_tool_call.take() {
160                                    // Handle empty input_json for tools with no parameters
161                                    let tool_input = if tool_call.input_json.is_empty() {
162                                        serde_json::json!({})
163                                    } else {
164                                        serde_json::from_str(tool_call.input_json.as_str())?
165                                    };
166                                    yield Ok(RawStreamingChoice::ToolCall(RawStreamingToolCall::new(tool_call.id, tool_call.name, tool_input)));
167                                } else {
168                                    yield Err(CompletionError::ProviderError("Failed to call tool".into()))
169                                }
170                            }
171                            aws_bedrock::StopReason::MaxTokens => {
172                                yield Err(CompletionError::ProviderError("Exceeded max tokens".into()))
173                            }
174                            _ => {}
175                        }
176                    },
177                    aws_bedrock::ConverseStreamOutput::Metadata(metadata_event) => {
178                        // Extract usage information from metadata
179                        if let Some(usage) = metadata_event.usage {
180                            yield Ok(RawStreamingChoice::FinalResponse(BedrockStreamingResponse {
181                                usage: Some(BedrockUsage {
182                                    input_tokens: usage.input_tokens,
183                                    output_tokens: usage.output_tokens,
184                                    total_tokens: usage.total_tokens,
185                                }),
186                            }));
187                        }
188                    },
189                    _ => {}
190                }
191            }
192        });
193
194        Ok(StreamingCompletionResponse::stream(stream))
195    }
196}
197
#[cfg(test)]
mod tests {
    //! Unit tests for the usage types, their `GetTokenUsage` conversion, and the
    //! accumulator states used while a streamed content block is assembled.

    use super::*;

    /// Shorthand constructor for a `BedrockUsage`.
    fn usage_of(input_tokens: i32, output_tokens: i32, total_tokens: i32) -> BedrockUsage {
        BedrockUsage {
            input_tokens,
            output_tokens,
            total_tokens,
        }
    }

    /// Creates a tool-call accumulator with an empty JSON buffer.
    fn empty_tool_call(name: &str, id: &str) -> ToolCallState {
        ToolCallState {
            name: name.to_string(),
            id: id.to_string(),
            input_json: String::new(),
        }
    }

    #[test]
    fn test_bedrock_usage_creation() {
        let BedrockUsage {
            input_tokens,
            output_tokens,
            total_tokens,
        } = usage_of(100, 50, 150);

        assert_eq!(input_tokens, 100);
        assert_eq!(output_tokens, 50);
        assert_eq!(total_tokens, 150);
    }

    #[test]
    fn test_bedrock_streaming_response_with_usage() {
        let response = BedrockStreamingResponse {
            usage: Some(usage_of(200, 75, 275)),
        };

        let rig_usage = response.token_usage();
        assert!(rig_usage.is_some());

        if let Some(usage) = rig_usage {
            assert_eq!(usage.input_tokens, 200);
            assert_eq!(usage.output_tokens, 75);
            assert_eq!(usage.total_tokens, 275);
        }
    }

    #[test]
    fn test_bedrock_streaming_response_without_usage() {
        assert!(BedrockStreamingResponse { usage: None }
            .token_usage()
            .is_none());
    }

    #[test]
    fn test_get_token_usage_trait() {
        let response = BedrockStreamingResponse {
            usage: Some(usage_of(448, 68, 516)),
        };

        // Call through the trait path to prove the impl is wired up.
        let usage = GetTokenUsage::token_usage(&response).expect("Usage should be present");
        assert_eq!(
            (usage.input_tokens, usage.output_tokens, usage.total_tokens),
            (448, 68, 516)
        );
    }

    #[test]
    fn test_bedrock_usage_serde() {
        let original = usage_of(100, 50, 150);

        // Serialization keeps the snake_case field names.
        let json = serde_json::to_string(&original).expect("Should serialize");
        for fragment in [
            "\"input_tokens\":100",
            "\"output_tokens\":50",
            "\"total_tokens\":150",
        ] {
            assert!(json.contains(fragment));
        }

        // Round-trip back into the struct.
        let round_tripped: BedrockUsage = serde_json::from_str(&json).expect("Should deserialize");
        assert_eq!(round_tripped.input_tokens, original.input_tokens);
        assert_eq!(round_tripped.output_tokens, original.output_tokens);
        assert_eq!(round_tripped.total_tokens, original.total_tokens);
    }

    #[test]
    fn test_bedrock_streaming_response_serde() {
        let response = BedrockStreamingResponse {
            usage: Some(usage_of(200, 75, 275)),
        };

        let json = serde_json::to_string(&response).expect("Should serialize");
        assert!(json.contains("\"input_tokens\":200"));

        let deserialized: BedrockStreamingResponse =
            serde_json::from_str(&json).expect("Should deserialize");
        let usage = deserialized.usage;
        assert!(usage.is_some());

        let BedrockUsage {
            input_tokens,
            output_tokens,
            total_tokens,
        } = usage.unwrap();
        assert_eq!(input_tokens, 200);
        assert_eq!(output_tokens, 75);
        assert_eq!(total_tokens, 275);
    }

    #[test]
    fn test_reasoning_state_default() {
        let state = ReasoningState::default();
        assert!(state.content.is_empty());
        assert!(state.signature.is_none());
    }

    #[test]
    fn test_reasoning_state_accumulate_content() {
        let mut state = ReasoningState::default();
        for chunk in ["First chunk", " Second chunk", " Third chunk"] {
            state.content.push_str(chunk);
        }

        assert_eq!(state.content, "First chunk Second chunk Third chunk");
        assert!(state.signature.is_none());
    }

    #[test]
    fn test_reasoning_state_with_signature() {
        let mut state = ReasoningState::default();
        state.content += "Reasoning content";
        state.signature = Some("test_signature_456".to_owned());

        assert_eq!(state.content, "Reasoning content");
        assert_eq!(state.signature.as_deref(), Some("test_signature_456"));
    }

    #[test]
    fn test_reasoning_state_empty_content() {
        // A reasoning block may carry only a signature and no text.
        let state = ReasoningState {
            signature: Some(String::from("signature_only")),
            ..ReasoningState::default()
        };

        assert!(state.content.is_empty());
        assert!(state.signature.is_some());
    }

    #[test]
    fn test_tool_call_state_default() {
        let ToolCallState {
            name,
            id,
            input_json,
        } = ToolCallState::default();

        assert_eq!(name, "");
        assert_eq!(id, "");
        assert_eq!(input_json, "");
    }

    #[test]
    fn test_tool_call_state_accumulate_json() {
        let mut state = empty_tool_call("my_tool", "tool_123");
        state
            .input_json
            .extend(["{\"arg1\":", "\"value1\"", "}"]);

        assert_eq!(state.name, "my_tool");
        assert_eq!(state.id, "tool_123");
        assert_eq!(state.input_json, "{\"arg1\":\"value1\"}");
    }

    #[test]
    fn test_tool_call_state_empty_accumulation() {
        let state = empty_tool_call("test_tool", "tool_abc");

        assert_eq!(state.name, "test_tool");
        assert_eq!(state.id, "tool_abc");
        assert!(state.input_json.is_empty());
    }

    #[test]
    fn test_tool_call_state_single_chunk() {
        let mut state = empty_tool_call("get_weather", "call_123");
        state.input_json += "{\"location\":\"Paris\"}";

        assert_eq!(state.input_json, "{\"location\":\"Paris\"}");
    }

    #[test]
    fn test_tool_call_state_multiple_small_chunks() {
        let mut state = empty_tool_call("search", "call_xyz");

        // Simulate the argument JSON arriving split across many tiny deltas.
        state
            .input_json
            .extend(["{", "\"q", "uery", "\":", "\"R", "ust", "\"}"]);

        assert_eq!(state.input_json, "{\"query\":\"Rust\"}");
    }

    #[test]
    fn test_tool_call_state_complex_json_accumulation() {
        let mut state = empty_tool_call("analyze_data", "call_456");

        // Simulate a nested JSON document delivered in several pieces.
        let pieces = [
            "{\"data\":{",
            "\"values\":[1,2,3],",
            "\"metadata\":{\"source\":\"api\"}",
            "}}",
        ];
        for piece in pieces {
            state.input_json.push_str(piece);
        }

        assert_eq!(
            state.input_json,
            "{\"data\":{\"values\":[1,2,3],\"metadata\":{\"source\":\"api\"}}}"
        );

        // The concatenation must be a well-formed JSON object.
        let parsed: serde_json::Value =
            serde_json::from_str(&state.input_json).expect("Should parse as valid JSON");
        assert!(parsed.is_object());
    }

    #[test]
    fn test_reasoning_state_accumulation() {
        let mut state = ReasoningState::default();
        state
            .content
            .extend(["First, ", "I need to ", "analyze the problem."]);

        assert_eq!(state.content, "First, I need to analyze the problem.");
        assert!(state.signature.is_none());
    }

    #[test]
    fn test_reasoning_state_with_signature_accumulation() {
        let mut state = ReasoningState::default();
        state.content.push_str("Reasoning content here");
        state.signature = Some("sig_part1".to_string());

        // Signatures normally arrive in a single delta; append a second part anyway.
        if let Some(signature) = state.signature.as_mut() {
            signature.push_str("_part2");
        }

        assert_eq!(state.content, "Reasoning content here");
        assert_eq!(state.signature.as_deref(), Some("sig_part1_part2"));
    }
}