anchor_chain/models/
claude_3.rs1use std::fmt;
9
10use async_trait::async_trait;
11use aws_sdk_bedrockruntime::{primitives::Blob, Client};
12use serde::{Deserialize, Serialize};
13#[cfg(feature = "tracing")]
14use tracing::instrument;
15
16use crate::error::AnchorChainError;
17use crate::node::Node;
18
19#[derive(Debug, Clone, Serialize, Deserialize)]
22pub struct ClaudeImageSource {
23 #[serde(rename = "type")]
25 source_type: String,
26
27 media_type: String,
29
30 data: String,
32}
33
34#[derive(Debug, Clone, Serialize, Deserialize)]
36pub struct ClaudeMessageContent {
37 #[serde(rename = "type")]
39 pub content_type: String,
40
41 #[serde(skip_serializing_if = "Option::is_none")]
43 pub text: Option<String>,
44
45 #[serde(skip_serializing_if = "Option::is_none")]
47 pub source: Option<ClaudeImageSource>,
48}
49
50#[derive(Debug, Clone, Serialize, Deserialize)]
52pub struct ClaudeMessage {
53 pub role: Option<String>,
55
56 pub content: Vec<ClaudeMessageContent>,
58}
59
60#[derive(Debug, Clone, Serialize, Deserialize)]
62struct ClaudeMessagesRequest {
63 anthropic_version: String,
65 max_tokens: i32,
67 messages: Vec<ClaudeMessage>,
69
70 #[serde(skip_serializing_if = "Option::is_none")]
72 system: Option<String>,
73
74 #[serde(skip_serializing_if = "Option::is_none")]
75 temperature: Option<f32>,
76
77 #[serde(skip_serializing_if = "Option::is_none")]
78 top_p: Option<f32>,
79
80 #[serde(skip_serializing_if = "Option::is_none")]
81 top_k: Option<i32>,
82
83 #[serde(skip_serializing_if = "Option::is_none")]
84 stop_sequences: Option<Vec<String>>,
85}
86
87#[derive(Debug, Clone, Serialize, Deserialize)]
89struct ClaudeMessagesResponse {
90 content: Vec<ClaudeMessageContent>,
92}
93
94pub struct Claude3Bedrock {
100 system_prompt: String,
102 client: Client,
104}
105
106impl Claude3Bedrock {
107 pub async fn new(system_prompt: String) -> Self {
111 let config = aws_config::load_from_env().await;
112 let client = Client::new(&config);
113 Claude3Bedrock {
114 client,
115 system_prompt,
116 }
117 }
118}
119
120#[async_trait]
121impl Node for Claude3Bedrock {
122 type Input = String;
123 type Output = String;
124
125 #[cfg_attr(feature = "tracing", instrument(fields(system_prompt = self.system_prompt.as_str())))]
130 async fn process(&self, input: Self::Input) -> Result<Self::Output, AnchorChainError> {
131 let request = ClaudeMessagesRequest {
132 anthropic_version: "bedrock-2023-05-31".to_string(),
133 max_tokens: 512,
134 messages: vec![ClaudeMessage {
135 role: Some("user".to_string()),
136 content: vec![ClaudeMessageContent {
137 content_type: "text".to_string(),
138 text: Some(input.to_string()),
139 source: None,
140 }],
141 }],
142 system: Some(self.system_prompt.clone()),
143 temperature: None,
144 top_p: None,
145 top_k: None,
146 stop_sequences: None,
147 };
148
149 let body_blob = Blob::new(serde_json::to_string(&request)?);
150 let response = self
151 .client
152 .invoke_model()
153 .model_id("anthropic.claude-3-sonnet-20240229-v1:0")
154 .body(body_blob)
155 .content_type("application/json")
156 .send()
157 .await;
158
159 let response_blob = response?.body;
160 let response: ClaudeMessagesResponse = serde_json::from_slice(&response_blob.into_inner())?;
161
162 if response.content.is_empty() {
163 return Err(AnchorChainError::EmptyResponseError);
164 }
165
166 Ok(response.content[0]
167 .text
168 .clone()
169 .expect("No text in response"))
170 }
171}
172
173impl fmt::Debug for Claude3Bedrock {
174 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
175 f.debug_struct("Claude3Bedrock")
176 .field("system_prompt", &self.system_prompt)
177 .finish()
178 }
179}