1use crate::anthropic::v1::request::{ContentBlockParam, MessageParam};
2use crate::prompt::Role;
3use crate::prompt::{MessageNum, ResponseContent};
4use crate::traits::{MessageResponseExt, ResponseAdapter};
5use crate::TypeError;
6use potato_util::utils::{construct_structured_response, TokenLogProbs};
7use potato_util::PyHelperFuncs;
8use pyo3::prelude::*;
9use pyo3::IntoPyObjectExt;
10use serde::{Deserialize, Serialize};
11use serde_json::Value;
/// Citation whose location is given by `start_char_index` / `end_char_index` within
/// the document identified by `document_index`.
///
/// One of the shapes accepted by [`TextCitation`]. `type` is read-only from Python;
/// every other field has a Python getter and setter.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
#[pyclass]
pub struct CitationCharLocation {
    #[pyo3(get, set)]
    pub cited_text: String,
    #[pyo3(get, set)]
    pub document_index: i32,
    // NOTE(review): a plain `String` with no serde default rejects both JSON `null` and
    // a missing key, which would make this shape fail to deserialize — confirm against
    // the API schema whether this should be `Option<String>`.
    #[pyo3(get, set)]
    pub document_title: String,
    #[pyo3(get, set)]
    pub end_char_index: i32,
    // NOTE(review): same null/missing-key concern as `document_title`.
    #[pyo3(get, set)]
    pub file_id: String,
    #[pyo3(get, set)]
    pub start_char_index: i32,
    /// Discriminator string, serialized under the JSON key `type`.
    #[pyo3(get)]
    #[serde(rename = "type")]
    pub r#type: String,
}
31
/// Citation whose location is given by `start_page_number` / `end_page_number` within
/// the document identified by `document_index`.
///
/// One of the shapes accepted by [`TextCitation`]. `type` is read-only from Python;
/// every other field has a Python getter and setter.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
#[pyclass]
pub struct CitationPageLocation {
    #[pyo3(get, set)]
    pub cited_text: String,
    #[pyo3(get, set)]
    pub document_index: i32,
    // NOTE(review): a plain `String` rejects JSON `null` and a missing key — confirm
    // whether the API can omit/null this and whether it should be `Option<String>`.
    #[pyo3(get, set)]
    pub document_title: String,
    #[pyo3(get, set)]
    pub end_page_number: i32,
    // NOTE(review): same null/missing-key concern as `document_title`.
    #[pyo3(get, set)]
    pub file_id: String,
    #[pyo3(get, set)]
    pub start_page_number: i32,
    /// Discriminator string, serialized under the JSON key `type`.
    #[pyo3(get)]
    #[serde(rename = "type")]
    pub r#type: String,
}
51
/// Citation whose location is given by `start_block_index` / `end_block_index` within
/// the document identified by `document_index`.
///
/// One of the shapes accepted by [`TextCitation`]. `type` is read-only from Python;
/// every other field has a Python getter and setter.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
#[pyclass]
pub struct CitationContentBlockLocation {
    #[pyo3(get, set)]
    pub cited_text: String,
    #[pyo3(get, set)]
    pub document_index: i32,
    // NOTE(review): a plain `String` rejects JSON `null` and a missing key — confirm
    // whether this should be `Option<String>`.
    #[pyo3(get, set)]
    pub document_title: String,
    #[pyo3(get, set)]
    pub end_block_index: i32,
    // NOTE(review): same null/missing-key concern as `document_title`.
    #[pyo3(get, set)]
    pub file_id: String,
    #[pyo3(get, set)]
    pub start_block_index: i32,
    /// Discriminator string, serialized under the JSON key `type`.
    #[pyo3(get)]
    #[serde(rename = "type")]
    pub r#type: String,
}
71
/// Citation that points at a web search result (`url`, `title`, `encrypted_index`).
///
/// One of the shapes accepted by [`TextCitation`]. `type` is read-only from Python;
/// every other field has a Python getter and setter.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
#[pyclass]
pub struct CitationsWebSearchResultLocation {
    #[pyo3(get, set)]
    pub cited_text: String,
    #[pyo3(get, set)]
    pub encrypted_index: String,
    // NOTE(review): a plain `String` rejects JSON `null` — confirm whether the API can
    // send a null title here.
    #[pyo3(get, set)]
    pub title: String,
    /// Discriminator string, serialized under the JSON key `type`.
    #[pyo3(get)]
    #[serde(rename = "type")]
    pub r#type: String,
    #[pyo3(get, set)]
    pub url: String,
}
87
/// Citation that points at blocks (`start_block_index` / `end_block_index`) of the
/// search result identified by `search_result_index`.
///
/// One of the shapes accepted by [`TextCitation`]. `type` is read-only from Python;
/// every other field has a Python getter and setter.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
#[pyclass]
pub struct CitationsSearchResultLocation {
    #[pyo3(get, set)]
    pub cited_text: String,
    #[pyo3(get, set)]
    pub end_block_index: i32,
    #[pyo3(get, set)]
    pub search_result_index: i32,
    #[pyo3(get, set)]
    pub source: String,
    #[pyo3(get, set)]
    pub start_block_index: i32,
    // NOTE(review): a plain `String` rejects JSON `null` — confirm whether the API can
    // send a null title here.
    #[pyo3(get, set)]
    pub title: String,
    /// Discriminator string, serialized under the JSON key `type`.
    #[pyo3(get)]
    #[serde(rename = "type")]
    pub r#type: String,
}
107
/// Any of the citation shapes that can be attached to a [`TextBlock`].
///
/// `untagged`: serde tries the variants in declaration order and keeps the first one
/// that deserializes successfully, so the variants must stay distinguishable by their
/// required fields and their order here is significant. This enum is not a
/// `#[pyclass]`; `TextBlock`'s Python `citations` getter converts each entry into its
/// concrete class.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
#[serde(untagged)]
pub enum TextCitation {
    CharLocation(CitationCharLocation),
    PageLocation(CitationPageLocation),
    ContentBlockLocation(CitationContentBlockLocation),
    WebSearchResultLocation(CitationsWebSearchResultLocation),
    SearchResultLocation(CitationsSearchResultLocation),
}
118
/// A text content block, optionally carrying citations.
///
/// `citations` has no `#[pyo3]` attribute; Python reads it through the hand-written
/// `#[getter]` in this type's `#[pymethods]` block. It is omitted from serialized
/// output when `None`.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
#[pyclass]
pub struct TextBlock {
    #[pyo3(get, set)]
    pub text: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub citations: Option<Vec<TextCitation>>,
    /// Discriminator string, serialized under the JSON key `type`.
    #[pyo3(get)]
    #[serde(rename = "type")]
    pub r#type: String,
}
131
132#[pymethods]
133impl TextBlock {
134 #[getter]
135 pub fn citations<'py>(
136 &self,
137 py: Python<'py>,
138 ) -> Result<Option<Vec<Bound<'py, PyAny>>>, TypeError> {
139 match &self.citations {
140 None => Ok(None),
141 Some(cits) => {
142 let py_citations: Result<Vec<_>, _> = cits
143 .iter()
144 .map(|cit| match cit {
145 TextCitation::CharLocation(c) => c.clone().into_bound_py_any(py),
146 TextCitation::PageLocation(c) => c.clone().into_bound_py_any(py),
147 TextCitation::ContentBlockLocation(c) => c.clone().into_bound_py_any(py),
148 TextCitation::WebSearchResultLocation(c) => c.clone().into_bound_py_any(py),
149 TextCitation::SearchResultLocation(c) => c.clone().into_bound_py_any(py),
150 })
151 .collect();
152 Ok(Some(py_citations?))
153 }
154 }
155 }
156}
157
/// A `thinking` content block.
///
/// `signature` is an `Option`, so it may be absent or null in the JSON.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
#[pyclass]
pub struct ThinkingBlock {
    #[pyo3(get, set)]
    pub thinking: String,
    #[pyo3(get, set)]
    pub signature: Option<String>,
    /// Discriminator string, serialized under the JSON key `type`.
    #[pyo3(get)]
    #[serde(rename = "type")]
    pub r#type: String,
}
169
/// A redacted-thinking content block; its payload is kept verbatim in `data`.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
#[pyclass]
pub struct RedactedThinkingBlock {
    #[pyo3(get, set)]
    pub data: String,
    /// Discriminator string, serialized under the JSON key `type`.
    #[pyo3(get)]
    #[serde(rename = "type")]
    pub r#type: String,
}
179
/// A client tool-use request emitted by the model.
///
/// `input` is kept as raw JSON (`serde_json::Value`) and has no Python attribute.
/// The field set is identical to `ServerToolUseBlock`; only the `type` value
/// distinguishes the two.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
#[pyclass]
pub struct ToolUseBlock {
    #[pyo3(get, set)]
    pub id: String,
    #[pyo3(get, set)]
    pub name: String,
    pub input: Value,
    /// Discriminator string, serialized under the JSON key `type`.
    #[pyo3(get)]
    #[serde(rename = "type")]
    pub r#type: String,
}
192
/// A server-side tool-use block.
///
/// `input` is kept as raw JSON (`serde_json::Value`) and has no Python attribute.
/// The field set is identical to `ToolUseBlock`; only the `type` value
/// distinguishes the two.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
#[pyclass]
pub struct ServerToolUseBlock {
    #[pyo3(get, set)]
    pub id: String,
    #[pyo3(get, set)]
    pub name: String,
    pub input: Value,
    /// Discriminator string, serialized under the JSON key `type`.
    #[pyo3(get)]
    #[serde(rename = "type")]
    pub r#type: String,
}
205
/// One entry of a successful web-search tool result.
///
/// `page_age` is an `Option` and may be absent or null; the other string fields are
/// required.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
#[pyclass]
pub struct WebSearchResultBlock {
    #[pyo3(get, set)]
    pub encrypted_content: String,
    #[pyo3(get, set)]
    pub page_age: Option<String>,
    #[pyo3(get, set)]
    pub title: String,
    /// Discriminator string, serialized under the JSON key `type`.
    #[pyo3(get)]
    #[serde(rename = "type")]
    pub r#type: String,
    #[pyo3(get, set)]
    pub url: String,
}
221
/// Error payload of a failed web-search tool result, identified by `error_code`.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
#[pyclass]
pub struct WebSearchToolResultError {
    #[pyo3(get, set)]
    pub error_code: String,
    /// Discriminator string, serialized under the JSON key `type`.
    #[pyo3(get)]
    #[serde(rename = "type")]
    pub r#type: String,
}
231
/// Payload of a web-search tool result: either an error object or a list of results.
///
/// `untagged`: a JSON object is matched against `Error`, and a JSON array against
/// `Results`.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
#[serde(untagged)]
pub enum WebSearchToolResultBlockContent {
    Error(WebSearchToolResultError),
    Results(Vec<WebSearchResultBlock>),
}
238
/// Result block returned for a server-side web search, linked to its request via
/// `tool_use_id`.
///
/// `content` has no `#[pyo3]` attribute; Python reads it through the hand-written
/// `#[getter]` in this type's `#[pymethods]` block.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
#[pyclass]
pub struct WebSearchToolResultBlock {
    pub content: WebSearchToolResultBlockContent,
    #[pyo3(get, set)]
    pub tool_use_id: String,
    /// Discriminator string, serialized under the JSON key `type`.
    #[pyo3(get)]
    #[serde(rename = "type")]
    pub r#type: String,
}
250
251#[pymethods]
252impl WebSearchToolResultBlock {
253 #[getter]
254 pub fn content<'py>(&self, py: Python<'py>) -> Result<Bound<'py, PyAny>, TypeError> {
255 match &self.content {
256 WebSearchToolResultBlockContent::Error(err) => Ok(err.clone().into_bound_py_any(py)?),
257 WebSearchToolResultBlockContent::Results(results) => {
258 let py_list: Result<Vec<_>, _> = results
259 .iter()
260 .map(|r| r.clone().into_bound_py_any(py))
261 .collect();
262 Ok(py_list?.into_bound_py_any(py)?)
263 }
264 }
265 }
266}
267
268#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
269#[serde(untagged)]
270pub(crate) enum ResponseContentBlockInner {
271 Text(TextBlock),
272 Thinking(ThinkingBlock),
273 RedactedThinking(RedactedThinkingBlock),
274 ToolUse(ToolUseBlock),
275 ServerToolUse(ServerToolUseBlock),
276 WebSearchToolResult(WebSearchToolResultBlock),
277}
278
/// A single content block of a response.
///
/// It wraps the crate-private `ResponseContentBlockInner` and uses `flatten`, so the
/// block's own fields sit directly at this level of the JSON.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct ResponseContentBlock {
    #[serde(flatten)]
    inner: ResponseContentBlockInner,
}
284
285impl ResponseContentBlock {
286 pub fn to_pyobject<'py>(&self, py: Python<'py>) -> Result<Bound<'py, PyAny>, TypeError> {
288 match &self.inner {
289 ResponseContentBlockInner::Text(block) => Ok(block.clone().into_bound_py_any(py)?),
290 ResponseContentBlockInner::Thinking(block) => {
291 Ok(block.clone().into_bound_py_any(py)?)
292 }
293 ResponseContentBlockInner::RedactedThinking(block) => {
294 Ok(block.clone().into_bound_py_any(py)?)
295 }
296 ResponseContentBlockInner::ToolUse(block) => Ok(block.clone().into_bound_py_any(py)?),
297 ResponseContentBlockInner::ServerToolUse(block) => {
298 Ok(block.clone().into_bound_py_any(py)?)
299 }
300 ResponseContentBlockInner::WebSearchToolResult(block) => {
301 Ok(block.clone().into_bound_py_any(py)?)
302 }
303 }
304 }
305}
306
307impl MessageResponseExt for ResponseContentBlock {
308 fn to_message_num(&self) -> Result<MessageNum, TypeError> {
309 let content_block_param = ContentBlockParam::from_response_content_block(&self.inner)?;
311
312 let message_param = MessageParam {
314 content: vec![content_block_param],
315 role: Role::Assistant.to_string(),
316 };
317
318 Ok(MessageNum::AnthropicMessageV1(message_param))
320 }
321}
322
323#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
324#[serde(rename_all = "snake_case")]
325#[pyclass]
326pub enum StopReason {
327 EndTurn,
328 MaxTokens,
329 StopSequence,
330 ToolUse,
331}
332
/// Token usage reported with a response. It is exposed to Python as `AnthropicUsage`,
/// with read-only attributes.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
#[pyclass(name = "AnthropicUsage")]
pub struct Usage {
    #[pyo3(get)]
    pub input_tokens: i32,
    #[pyo3(get)]
    pub output_tokens: i32,
    // Omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[pyo3(get)]
    pub cache_creation_input_tokens: Option<i32>,
    // Omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[pyo3(get)]
    pub cache_read_input_tokens: Option<i32>,
    // Unlike the cache counters, serialized as `null` when `None`.
    #[pyo3(get)]
    pub service_tier: Option<String>,
}
349
/// Top-level Anthropic Messages API response.
///
/// Scalar fields are read-only Python attributes. `content` has no `#[pyo3]` attribute;
/// Python reads it through the hand-written `#[getter]` in this type's `#[pymethods]`
/// block, which converts each block into its concrete class.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
#[pyclass]
pub struct AnthropicMessageResponse {
    #[pyo3(get)]
    pub id: String,
    #[pyo3(get)]
    pub model: String,
    #[pyo3(get)]
    pub role: String,
    #[pyo3(get)]
    pub stop_reason: Option<StopReason>,
    #[pyo3(get)]
    pub stop_sequence: Option<String>,
    // No `rename` needed: serde derives the key `type` from the raw identifier.
    #[pyo3(get)]
    pub r#type: String,
    #[pyo3(get)]
    pub usage: Usage,
    pub content: Vec<ResponseContentBlock>,
}
369
370#[pymethods]
371impl AnthropicMessageResponse {
372 #[getter]
373 pub fn content<'py>(&self, py: Python<'py>) -> Result<Vec<Bound<'py, PyAny>>, TypeError> {
374 self.content
375 .iter()
376 .map(|block| block.to_pyobject(py))
377 .collect::<Result<Vec<_>, _>>()
378 .map_err(|e| TypeError::Error(e.to_string()))
379 }
380}
381
382impl ResponseAdapter for AnthropicMessageResponse {
383 fn __str__(&self) -> String {
384 PyHelperFuncs::__str__(self)
385 }
386
387 fn is_empty(&self) -> bool {
388 self.content.is_empty()
389 }
390
391 fn to_bound_py_object<'py>(&self, py: Python<'py>) -> Result<Bound<'py, PyAny>, TypeError> {
392 Ok(PyHelperFuncs::to_bound_py_object(py, self)?)
393 }
394
395 fn id(&self) -> &str {
396 &self.id
397 }
398
399 fn to_message_num(&self) -> Result<Vec<MessageNum>, TypeError> {
400 let mut results = Vec::new();
401 for content in &self.content {
402 match content.to_message_num() {
403 Ok(msg) => results.push(msg),
404 Err(e) => return Err(e),
405 }
406 }
407 Ok(results)
408 }
409
410 fn get_content(&self) -> ResponseContent {
411 ResponseContent::Anthropic(self.content.first().cloned().unwrap())
412 }
413
414 fn tool_call_output(&self) -> Option<Value> {
415 for block in &self.content {
416 if let ResponseContentBlockInner::ToolUse(tool_use_block) = &block.inner {
417 return serde_json::to_value(tool_use_block).ok();
418 }
419 }
420 None
421 }
422
423 fn structured_output<'py>(
424 &self,
425 py: Python<'py>,
426 output_model: Option<&Bound<'py, PyAny>>,
427 ) -> Result<Bound<'py, PyAny>, TypeError> {
428 if self.content.is_empty() {
429 return Ok(py.None().into_bound_py_any(py)?);
430 }
431
432 let inner = self.content.first().cloned().unwrap().inner;
433
434 match inner {
435 ResponseContentBlockInner::Text(block) => {
436 return Ok(construct_structured_response(py, block.text, output_model)?)
437 }
438 _ => return Ok(py.None().into_bound_py_any(py)?),
439 };
440 }
441
442 fn structured_output_value(&self) -> Option<Value> {
443 if self.content.is_empty() {
444 return None;
445 }
446
447 let inner = self.content.first().cloned().unwrap().inner;
448 match inner {
449 ResponseContentBlockInner::Text(block) => serde_json::from_str(&block.text).ok(),
450 _ => None,
451 }
452 }
453
454 fn usage<'py>(&self, py: Python<'py>) -> Result<Bound<'py, PyAny>, TypeError> {
455 Ok(PyHelperFuncs::to_bound_py_object(py, &self.usage)?)
456 }
457
458 fn get_log_probs(&self) -> Vec<TokenLogProbs> {
459 Vec::new()
461 }
462
463 fn response_text(&self) -> String {
464 if self.content.is_empty() {
465 return String::new();
466 }
467
468 let inner = self.content.first().cloned().unwrap().inner;
469
470 match inner {
471 ResponseContentBlockInner::Text(block) => block.text,
472 _ => String::new(),
473 }
474 }
475}