Skip to main content

rig_bedrock/
streaming.rs

1use crate::types::completion_request::AwsCompletionRequest;
2use crate::{
3    completion::{CompletionModel, resolve_request_model},
4    types::errors::AwsSdkConverseStreamError,
5};
6use async_stream::stream;
7use aws_sdk_bedrockruntime::types as aws_bedrock;
8use rig::completion::GetTokenUsage;
9use rig::streaming::StreamingCompletionResponse;
10use rig::{
11    completion::CompletionError,
12    message::ReasoningContent,
13    streaming::{RawStreamingChoice, RawStreamingToolCall, ToolCallDeltaContent},
14};
15use serde::{Deserialize, Serialize};
16
17#[derive(Clone, Deserialize, Serialize)]
18pub struct BedrockStreamingResponse {
19    pub usage: Option<BedrockUsage>,
20}
21
22#[derive(Clone, Deserialize, Serialize)]
23pub struct BedrockUsage {
24    pub input_tokens: i32,
25    pub output_tokens: i32,
26    pub total_tokens: i32,
27    #[serde(skip_serializing_if = "Option::is_none")]
28    pub cache_read_input_tokens: Option<i32>,
29    #[serde(skip_serializing_if = "Option::is_none")]
30    pub cache_write_input_tokens: Option<i32>,
31}
32
33impl GetTokenUsage for BedrockStreamingResponse {
34    fn token_usage(&self) -> Option<rig::completion::Usage> {
35        self.usage.as_ref().map(|u| rig::completion::Usage {
36            input_tokens: u.input_tokens as u64,
37            output_tokens: u.output_tokens as u64,
38            total_tokens: u.total_tokens as u64,
39            cached_input_tokens: u.cache_read_input_tokens.unwrap_or_default() as u64,
40            cache_creation_input_tokens: u.cache_write_input_tokens.unwrap_or_default() as u64,
41        })
42    }
43}
44
/// Accumulates one in-progress tool-use content block from the stream.
#[derive(Debug, Default)]
struct ToolCallState {
    /// Tool name taken from the block-start event.
    name: String,
    /// Tool-use id assigned by Bedrock in the block-start event.
    id: String,
    /// Locally generated id attached to this call's deltas and to the final tool call.
    internal_call_id: String,
    /// Tool-input JSON fragments, concatenated in arrival order.
    input_json: String,
}
52
/// Accumulates one in-progress reasoning content block from the stream.
#[derive(Debug, Default)]
struct ReasoningState {
    /// Reasoning text deltas, concatenated in arrival order.
    content: String,
    /// Signature delta for the block, if one was received.
    signature: Option<String>,
}
58
59impl CompletionModel {
60    pub(crate) async fn stream(
61        &self,
62        completion_request: rig::completion::CompletionRequest,
63    ) -> Result<StreamingCompletionResponse<BedrockStreamingResponse>, CompletionError> {
64        let request_model = resolve_request_model(&self.model, &completion_request);
65        let request = AwsCompletionRequest {
66            inner: completion_request,
67            prompt_caching: self.prompt_caching,
68        };
69
70        let mut converse_builder = self
71            .client
72            .get_inner()
73            .await
74            .converse_stream()
75            .model_id(request_model);
76
77        let tool_config = request.tools_config()?;
78        let prompt_with_history = request.messages()?;
79        converse_builder = converse_builder
80            .set_additional_model_request_fields(request.additional_params())
81            .set_inference_config(request.inference_config())
82            .set_tool_config(tool_config)
83            .set_system(request.system_prompt())
84            .set_messages(Some(prompt_with_history));
85
86        let response = converse_builder.send().await.map_err(|sdk_error| {
87            Into::<CompletionError>::into(AwsSdkConverseStreamError(sdk_error))
88        })?;
89
90        let stream = Box::pin(stream! {
91            let mut current_tool_call: Option<ToolCallState> = None;
92            let mut current_reasoning: Option<ReasoningState> = None;
93            let mut stream = response.stream;
94            while let Ok(Some(output)) = stream.recv().await {
95                match output {
96                    aws_bedrock::ConverseStreamOutput::ContentBlockDelta(event) => {
97                        let delta = event.delta.ok_or(CompletionError::ProviderError("The delta for a content block is missing".into()))?;
98                        match delta {
99                            aws_bedrock::ContentBlockDelta::Text(text) => {
100                                if current_tool_call.is_none() {
101                                    yield Ok(RawStreamingChoice::Message(text))
102                                }
103                            },
104                            aws_bedrock::ContentBlockDelta::ToolUse(tool) => {
105                                if let Some(ref mut tool_call) = current_tool_call {
106                                    let delta = tool.input().to_string();
107                                    tool_call.input_json.push_str(&delta);
108
109                                    // Emit the delta so UI can show progress
110                                    yield Ok(RawStreamingChoice::ToolCallDelta {
111                                        id: tool_call.id.clone(),
112                                        internal_call_id: tool_call.internal_call_id.clone(),
113                                        content: ToolCallDeltaContent::Delta(delta),
114                                    });
115                                }
116                            },
117                            aws_bedrock::ContentBlockDelta::ReasoningContent(reasoning) => {
118                                match reasoning {
119                                    aws_bedrock::ReasoningContentBlockDelta::Text(text) => {
120                                        if current_reasoning.is_none() {
121                                            current_reasoning = Some(ReasoningState::default());
122                                        }
123
124                                        if let Some(ref mut state) = current_reasoning {
125                                            state.content.push_str(text.as_str());
126                                        }
127
128                                        if !text.is_empty() {
129                                            yield Ok(RawStreamingChoice::ReasoningDelta {
130                                                reasoning: text.clone(),
131                                                id: None,
132                                            })
133                                        }
134                                    },
135                                    aws_bedrock::ReasoningContentBlockDelta::Signature(signature) => {
136                                        if current_reasoning.is_none() {
137                                            current_reasoning = Some(ReasoningState::default());
138                                        }
139
140                                        if let Some(ref mut state) = current_reasoning {
141                                            state.signature = Some(signature.clone());
142                                        }
143                                    },
144                                    _ => {}
145                                }
146                            },
147                            _ => {}
148                        }
149                    },
150                    aws_bedrock::ConverseStreamOutput::ContentBlockStart(event) => {
151                        match event.start.ok_or(CompletionError::ProviderError("ContentBlockStart has no data".into()))? {
152                            aws_bedrock::ContentBlockStart::ToolUse(tool_use) => {
153                                let internal_call_id = nanoid::nanoid!();
154                                current_tool_call = Some(ToolCallState {
155                                    name: tool_use.name.clone(),
156                                    id: tool_use.tool_use_id.clone(),
157                                    internal_call_id: internal_call_id.clone(),
158                                    input_json: String::new(),
159                                });
160                                yield Ok(RawStreamingChoice::ToolCallDelta {
161                                    id: tool_use.tool_use_id,
162                                    internal_call_id,
163                                    content: ToolCallDeltaContent::Name(tool_use.name),
164                                });
165                            },
166                            _ => yield Err(CompletionError::ProviderError("Stream is empty".into()))
167                        }
168                    },
169                    aws_bedrock::ConverseStreamOutput::ContentBlockStop(_event) => {
170                        if let Some(reasoning_state) = current_reasoning.take()
171                            && !reasoning_state.content.is_empty() {
172                                yield Ok(RawStreamingChoice::Reasoning {
173                                    id: None,
174                                    content: ReasoningContent::Text {
175                                        text: reasoning_state.content,
176                                        signature: reasoning_state.signature,
177                                    },
178                                })
179                            }
180                    },
181                    aws_bedrock::ConverseStreamOutput::MessageStop(message_stop_event) => {
182                        match message_stop_event.stop_reason {
183                            aws_bedrock::StopReason::ToolUse => {
184                                if let Some(tool_call) = current_tool_call.take() {
185                                    // Handle empty input_json for tools with no parameters
186                                    let tool_input = if tool_call.input_json.is_empty() {
187                                        serde_json::json!({})
188                                    } else {
189                                        serde_json::from_str(tool_call.input_json.as_str())?
190                                    };
191                                    yield Ok(RawStreamingChoice::ToolCall(
192                                        RawStreamingToolCall::new(tool_call.id, tool_call.name, tool_input)
193                                            .with_internal_call_id(tool_call.internal_call_id)
194                                    ));
195                                } else {
196                                    yield Err(CompletionError::ProviderError("Failed to call tool".into()))
197                                }
198                            }
199                            aws_bedrock::StopReason::MaxTokens => {
200                                yield Err(CompletionError::ProviderError("Exceeded max tokens".into()))
201                            }
202                            _ => {}
203                        }
204                    },
205                    aws_bedrock::ConverseStreamOutput::Metadata(metadata_event) => {
206                        // Extract usage information from metadata
207                        if let Some(usage) = metadata_event.usage {
208                            yield Ok(RawStreamingChoice::FinalResponse(BedrockStreamingResponse {
209                                usage: Some(BedrockUsage {
210                                    input_tokens: usage.input_tokens,
211                                    output_tokens: usage.output_tokens,
212                                    total_tokens: usage.total_tokens,
213                                    cache_read_input_tokens: usage.cache_read_input_tokens,
214                                    cache_write_input_tokens: usage.cache_write_input_tokens,
215                                }),
216                            }));
217                        }
218                    },
219                    _ => {}
220                }
221            }
222        });
223
224        Ok(StreamingCompletionResponse::stream(stream))
225    }
226}
227
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a [`BedrockUsage`] from its five counters.
    fn bedrock_usage(
        input_tokens: i32,
        output_tokens: i32,
        total_tokens: i32,
        cache_read_input_tokens: Option<i32>,
        cache_write_input_tokens: Option<i32>,
    ) -> BedrockUsage {
        BedrockUsage {
            input_tokens,
            output_tokens,
            total_tokens,
            cache_read_input_tokens,
            cache_write_input_tokens,
        }
    }

    /// Builds a [`ToolCallState`] with a fresh internal id and no input yet.
    fn tool_state(name: &str, id: &str) -> ToolCallState {
        ToolCallState {
            name: name.to_owned(),
            id: id.to_owned(),
            internal_call_id: nanoid::nanoid!(),
            input_json: String::new(),
        }
    }

    // ---- usage ---------------------------------------------------------------

    #[test]
    fn test_bedrock_usage_creation() {
        let usage = bedrock_usage(100, 50, 150, None, None);

        assert_eq!(
            (usage.input_tokens, usage.output_tokens, usage.total_tokens),
            (100, 50, 150)
        );
    }

    #[test]
    fn test_bedrock_streaming_response_with_usage() {
        let response = BedrockStreamingResponse {
            usage: Some(bedrock_usage(200, 75, 275, Some(40), Some(10))),
        };

        let usage = response.token_usage();
        assert!(usage.is_some());
        let usage = usage.unwrap();

        assert_eq!(
            (usage.input_tokens, usage.output_tokens, usage.total_tokens),
            (200, 75, 275)
        );
        assert_eq!(
            (usage.cached_input_tokens, usage.cache_creation_input_tokens),
            (40, 10)
        );
    }

    #[test]
    fn test_bedrock_streaming_response_without_usage() {
        assert!(BedrockStreamingResponse { usage: None }.token_usage().is_none());
    }

    #[test]
    fn test_get_token_usage_trait() {
        let response = BedrockStreamingResponse {
            usage: Some(bedrock_usage(448, 68, 516, Some(80), Some(20))),
        };

        // Exercise the `GetTokenUsage` implementation directly.
        let usage = response.token_usage().expect("Usage should be present");
        assert_eq!(
            (usage.input_tokens, usage.output_tokens, usage.total_tokens),
            (448, 68, 516)
        );
        assert_eq!(
            (usage.cached_input_tokens, usage.cache_creation_input_tokens),
            (80, 20)
        );
    }

    #[test]
    fn test_bedrock_usage_serde() {
        let usage = bedrock_usage(100, 50, 150, Some(25), Some(5));

        let json = serde_json::to_string(&usage).expect("Should serialize");
        for needle in [
            r#""input_tokens":100"#,
            r#""output_tokens":50"#,
            r#""total_tokens":150"#,
        ] {
            assert!(json.contains(needle));
        }

        let round_tripped: BedrockUsage = serde_json::from_str(&json).expect("Should deserialize");
        assert_eq!(round_tripped.input_tokens, usage.input_tokens);
        assert_eq!(round_tripped.output_tokens, usage.output_tokens);
        assert_eq!(round_tripped.total_tokens, usage.total_tokens);
        assert_eq!(
            round_tripped.cache_read_input_tokens,
            usage.cache_read_input_tokens
        );
        assert_eq!(
            round_tripped.cache_write_input_tokens,
            usage.cache_write_input_tokens
        );
    }

    #[test]
    fn test_bedrock_streaming_response_serde() {
        let response = BedrockStreamingResponse {
            usage: Some(bedrock_usage(200, 75, 275, Some(30), Some(15))),
        };

        let json = serde_json::to_string(&response).expect("Should serialize");
        assert!(json.contains(r#""input_tokens":200"#));

        let round_tripped: BedrockStreamingResponse =
            serde_json::from_str(&json).expect("Should deserialize");
        let usage = round_tripped.usage;
        assert!(usage.is_some());
        let usage = usage.unwrap();

        assert_eq!(
            (usage.input_tokens, usage.output_tokens, usage.total_tokens),
            (200, 75, 275)
        );
        assert_eq!(usage.cache_read_input_tokens, Some(30));
        assert_eq!(usage.cache_write_input_tokens, Some(15));
    }

    // ---- reasoning state -----------------------------------------------------

    #[test]
    fn test_reasoning_state_default() {
        let state = ReasoningState::default();

        assert!(state.content.is_empty());
        assert!(state.signature.is_none());
    }

    #[test]
    fn test_reasoning_state_accumulate_content() {
        let mut state = ReasoningState::default();
        state
            .content
            .extend(["First chunk", " Second chunk", " Third chunk"]);

        assert_eq!(state.content, "First chunk Second chunk Third chunk");
        assert!(state.signature.is_none());
    }

    #[test]
    fn test_reasoning_state_with_signature() {
        let mut state = ReasoningState::default();
        state.content += "Reasoning content";
        state.signature = Some(String::from("test_signature_456"));

        assert_eq!(state.content, "Reasoning content");
        assert_eq!(state.signature.as_deref(), Some("test_signature_456"));
    }

    #[test]
    fn test_reasoning_state_empty_content() {
        // A signature may arrive without any reasoning text.
        let state = ReasoningState {
            signature: Some("signature_only".into()),
            ..ReasoningState::default()
        };

        assert!(state.content.is_empty());
        assert!(state.signature.is_some());
    }

    // ---- tool call state -----------------------------------------------------

    #[test]
    fn test_tool_call_state_default() {
        let state = ToolCallState::default();

        assert!(state.name.is_empty());
        assert!(state.id.is_empty());
        assert!(state.input_json.is_empty());
    }

    #[test]
    fn test_tool_call_state_accumulate_json() {
        let mut state = tool_state("my_tool", "tool_123");
        state.input_json.extend([r#"{"arg1":"#, r#""value1""#]);
        state.input_json.push('}');

        assert_eq!(state.name, "my_tool");
        assert_eq!(state.id, "tool_123");
        assert_eq!(state.input_json, r#"{"arg1":"value1"}"#);
    }

    #[test]
    fn test_tool_call_state_empty_accumulation() {
        let state = tool_state("test_tool", "tool_abc");

        assert_eq!(state.name, "test_tool");
        assert_eq!(state.id, "tool_abc");
        assert!(state.input_json.is_empty());
    }

    #[test]
    fn test_tool_call_state_single_chunk() {
        let mut state = tool_state("get_weather", "call_123");
        state.input_json += r#"{"location":"Paris"}"#;

        assert_eq!(state.input_json, r#"{"location":"Paris"}"#);
    }

    #[test]
    fn test_tool_call_state_multiple_small_chunks() {
        let mut state = tool_state("search", "call_xyz");

        // The input may arrive split at arbitrary character boundaries.
        state
            .input_json
            .extend(["{", "\"q", "uery", "\":", "\"R", "ust", "\"}"]);

        assert_eq!(state.input_json, r#"{"query":"Rust"}"#);
    }

    #[test]
    fn test_tool_call_state_complex_json_accumulation() {
        let mut state = tool_state("analyze_data", "call_456");

        // A nested object arriving in structural pieces.
        let fragments = [
            r#"{"data":{"#,
            r#""values":[1,2,3],"#,
            r#""metadata":{"source":"api"}"#,
            "}}",
        ];
        for fragment in fragments {
            state.input_json.push_str(fragment);
        }

        assert_eq!(
            state.input_json,
            r#"{"data":{"values":[1,2,3],"metadata":{"source":"api"}}}"#
        );

        // The concatenation must be a well-formed JSON object.
        let parsed: serde_json::Value =
            serde_json::from_str(&state.input_json).expect("Should parse as valid JSON");
        assert!(parsed.is_object());
    }

    #[test]
    fn test_reasoning_state_accumulation() {
        let mut state = ReasoningState::default();
        for piece in ["First, ", "I need to ", "analyze the problem."] {
            state.content.push_str(piece);
        }

        assert_eq!(state.content, "First, I need to analyze the problem.");
        assert!(state.signature.is_none());
    }

    #[test]
    fn test_reasoning_state_with_signature_accumulation() {
        let mut state = ReasoningState {
            content: "Reasoning content here".to_owned(),
            signature: Some("sig_part1".to_owned()),
        };

        // Simulate a signature built from several parts (in practice it arrives in one chunk).
        if let Some(signature) = state.signature.as_mut() {
            signature.push_str("_part2");
        }

        assert_eq!(state.content, "Reasoning content here");
        assert_eq!(state.signature.as_deref(), Some("sig_part1_part2"));
    }
}