anchor_chain/models/
claude_3.rs

//! Module for interfacing with Claude 3 via AWS Bedrock.
//!
//! Provides the request/response types and a node implementation for building and sending
//! requests to Claude 3 models hosted on AWS Bedrock, so that LLM calls can be used as steps
//! in a processing chain. The message types can represent both text and image content; the
//! `Claude3Bedrock` node itself currently sends text-only prompts.
7
8use std::fmt;
9
10use async_trait::async_trait;
11use aws_sdk_bedrockruntime::{primitives::Blob, Client};
12use serde::{Deserialize, Serialize};
13#[cfg(feature = "tracing")]
14use tracing::instrument;
15
16use crate::error::AnchorChainError;
17use crate::node::Node;
18
19/// Represents a source of an image to be processed by Claude 3, encapsulating the necessary
20/// details for image handling.
21#[derive(Debug, Clone, Serialize, Deserialize)]
22pub struct ClaudeImageSource {
23    /// Specifies the data type of the source, currently only "base64" is supported.
24    #[serde(rename = "type")]
25    source_type: String,
26
27    /// Indicates the media type of the image, e.g., "image/jpeg".
28    media_type: String,
29
30    /// Contains the base64-encoded image data.
31    data: String,
32}
33
34/// Defines the content of a message for Claude 3, accommodating text and images.
35#[derive(Debug, Clone, Serialize, Deserialize)]
36pub struct ClaudeMessageContent {
37    /// The content type, e.g., "text".
38    #[serde(rename = "type")]
39    pub content_type: String,
40
41    /// The actual text content, if applicable.
42    #[serde(skip_serializing_if = "Option::is_none")]
43    pub text: Option<String>,
44
45    /// An image source, if applicable.
46    #[serde(skip_serializing_if = "Option::is_none")]
47    pub source: Option<ClaudeImageSource>,
48}
49
50/// Represents a message to be sent to Claude 3, comprising one or more content items.
51#[derive(Debug, Clone, Serialize, Deserialize)]
52pub struct ClaudeMessage {
53    /// The role of the message, e.g., "user".
54    pub role: Option<String>,
55
56    /// A vector of content items within the message.
57    pub content: Vec<ClaudeMessageContent>,
58}
59
60/// Struct to configure and send a request to Claude 3 model via AWS Bedrock.
61#[derive(Debug, Clone, Serialize, Deserialize)]
62struct ClaudeMessagesRequest {
63    /// Specifies the version of the anthropic model to use.
64    anthropic_version: String,
65    /// Sets the maximum number of tokens to generate.
66    max_tokens: i32,
67    /// Contains the messages to process.
68    messages: Vec<ClaudeMessage>,
69
70    // Optional parameters for model invocation.
71    #[serde(skip_serializing_if = "Option::is_none")]
72    system: Option<String>,
73
74    #[serde(skip_serializing_if = "Option::is_none")]
75    temperature: Option<f32>,
76
77    #[serde(skip_serializing_if = "Option::is_none")]
78    top_p: Option<f32>,
79
80    #[serde(skip_serializing_if = "Option::is_none")]
81    top_k: Option<i32>,
82
83    #[serde(skip_serializing_if = "Option::is_none")]
84    stop_sequences: Option<Vec<String>>,
85}
86
87/// Holds the response content from a Claude 3 processing request.
88#[derive(Debug, Clone, Serialize, Deserialize)]
89struct ClaudeMessagesResponse {
90    /// The processed content returned by Claude.
91    content: Vec<ClaudeMessageContent>,
92}
93
94/// A processor for integrating Claude 3 LLM processing within a chain.
95///
96/// `Claude3Bedrock` allows for sending requests to Claude 3 models, handling both text and image inputs.
97/// It encapsulates the necessary details for AWS Bedrock interaction and provides an asynchronous
98/// interface for processing content through Claude 3.
99pub struct Claude3Bedrock {
100    /// The system prompt or context to use for all requests.
101    system_prompt: String,
102    /// The AWS Bedrock client for sending requests.
103    client: Client,
104}
105
106impl Claude3Bedrock {
107    /// Constructs a new `Claude3Bedrock` processor with the specified system prompt.
108    ///
109    /// Initializes the AWS Bedrock client using the environment's AWS configuration.
110    pub async fn new(system_prompt: String) -> Self {
111        let config = aws_config::load_from_env().await;
112        let client = Client::new(&config);
113        Claude3Bedrock {
114            client,
115            system_prompt,
116        }
117    }
118}
119
120#[async_trait]
121impl Node for Claude3Bedrock {
122    type Input = String;
123    type Output = String;
124
125    /// Processes the input through the Claude 3 model, returning the model's output.
126    ///
127    /// Constructs a request to the Claude 3 model with the provided input, sends it via
128    /// AWS Bedrock, and extracts the text content from the response.
129    #[cfg_attr(feature = "tracing", instrument(fields(system_prompt = self.system_prompt.as_str())))]
130    async fn process(&self, input: Self::Input) -> Result<Self::Output, AnchorChainError> {
131        let request = ClaudeMessagesRequest {
132            anthropic_version: "bedrock-2023-05-31".to_string(),
133            max_tokens: 512,
134            messages: vec![ClaudeMessage {
135                role: Some("user".to_string()),
136                content: vec![ClaudeMessageContent {
137                    content_type: "text".to_string(),
138                    text: Some(input.to_string()),
139                    source: None,
140                }],
141            }],
142            system: Some(self.system_prompt.clone()),
143            temperature: None,
144            top_p: None,
145            top_k: None,
146            stop_sequences: None,
147        };
148
149        let body_blob = Blob::new(serde_json::to_string(&request)?);
150        let response = self
151            .client
152            .invoke_model()
153            .model_id("anthropic.claude-3-sonnet-20240229-v1:0")
154            .body(body_blob)
155            .content_type("application/json")
156            .send()
157            .await;
158
159        let response_blob = response?.body;
160        let response: ClaudeMessagesResponse = serde_json::from_slice(&response_blob.into_inner())?;
161
162        if response.content.is_empty() {
163            return Err(AnchorChainError::EmptyResponseError);
164        }
165
166        Ok(response.content[0]
167            .text
168            .clone()
169            .expect("No text in response"))
170    }
171}
172
173impl fmt::Debug for Claude3Bedrock {
174    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
175        f.debug_struct("Claude3Bedrock")
176            .field("system_prompt", &self.system_prompt)
177            .finish()
178    }
179}