Skip to main content

potato_type/anthropic/v1/
request.rs

1use crate::anthropic::v1::response::{
2    ResponseContentBlockInner, TextCitation, WebSearchToolResultBlockContent,
3};
4use crate::common::get_image_media_types;
5use crate::google::v1::generate::request::{DataNum, GeminiContent, Part};
6use crate::openai::v1::chat::request::{ChatMessage, ContentPart, TextContentPart};
7use crate::prompt::builder::ProviderRequest;
8use crate::prompt::MessageNum;
9use crate::prompt::ModelSettings;
10use crate::tools::AgentToolDefinition;
11use crate::traits::get_var_regex;
12use crate::traits::{MessageConversion, MessageFactory, PromptMessageExt, RequestAdapter};
13use crate::TypeError;
14use crate::{Provider, SettingsType};
15use potato_util::{PyHelperFuncs, UtilError};
16use pyo3::prelude::*;
17use pyo3::types::PyList;
18use pyo3::IntoPyObjectExt;
19use pythonize::{depythonize, pythonize};
20use serde::{Deserialize, Serialize};
21use serde_json::Value;
22use std::collections::HashSet;
23
// Common `type` / `media_type` strings used in Anthropic message content.

/// `type` of a base64-encoded source ([`Base64ImageSource`], [`Base64PDFSource`]).
pub const BASE64_TYPE: &str = "base64";
/// `type` of a URL-referenced source ([`UrlImageSource`], [`UrlPDFSource`]).
pub const URL_TYPE: &str = "url";
/// Not referenced in this part of the file; presumably the cache-control type — confirm at use site.
pub const EPHEMERAL_TYPE: &str = "ephemeral";
/// `type` of an [`ImageBlockParam`].
pub const IMAGE_TYPE: &str = "image";
/// `type` of a [`TextBlockParam`] and of a [`PlainTextSource`].
pub const TEXT_TYPE: &str = "text";
/// `type` of a [`DocumentBlockParam`].
pub const DOCUMENT_TYPE: &str = "document";
/// `media_type` assigned to every [`Base64PDFSource`].
pub const DOCUMENT_BASE64_PDF_TYPE: &str = "application/pdf";
/// `media_type` assigned to every [`PlainTextSource`].
pub const DOCUMENT_PLAIN_TEXT_TYPE: &str = "text/plain";
/// `type` of a [`WebSearchResultBlockParam`].
pub const WEB_SEARCH_RESULT_TYPE: &str = "web_search_result";
/// `type` of a [`SearchResultBlockParam`].
pub const SEARCH_TYPE: &str = "search_result";
/// `type` of a [`ThinkingBlockParam`].
pub const THINKING_TYPE: &str = "thinking";
/// `type` of a [`RedactedThinkingBlockParam`].
pub const REDACTED_THINKING_TYPE: &str = "redacted_thinking";
/// `type` of a [`ToolUseBlockParam`].
pub const TOOL_USE_TYPE: &str = "tool_use";
/// `type` of a [`ToolResultBlockParam`].
pub const TOOL_RESULT_TYPE: &str = "tool_result";
/// `type` of a [`WebSearchToolResultBlockParam`].
pub const WEB_SEARCH_TOOL_RESULT_TYPE: &str = "web_search_tool_result";
/// `type` of a [`ServerToolUseBlockParam`].
pub const SERVER_TOOL_USE_TYPE: &str = "server_tool_use";
41
// `type` strings for citation payloads.

/// `type` of a [`CitationCharLocationParam`].
pub const CHAR_LOCATION_TYPE: &str = "char_location";
/// `type` of a [`CitationPageLocationParam`].
pub const PAGE_LOCATION_TYPE: &str = "page_location";
/// `type` of a [`CitationContentBlockLocationParam`].
pub const CONTENT_BLOCK_LOCATION_TYPE: &str = "content_block_location";
/// `type` of a [`CitationWebSearchResultLocationParam`].
pub const WEB_SEARCH_RESULT_LOCATION_TYPE: &str = "web_search_result_location";
/// `type` of a [`CitationSearchResultLocationParam`].
pub const SEARCH_RESULT_LOCATION_TYPE: &str = "search_result_location";
/// Not referenced in this part of the file; presumably marks a web-search tool error
/// payload rather than a citation — confirm at use site.
pub const WEB_SEARCH_TOOL_RESULT_ERROR_TYPE: &str = "web_search_tool_result_error";
49
50#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
51#[pyclass]
52pub struct CitationCharLocationParam {
53    #[pyo3(get, set)]
54    pub cited_text: String,
55    #[pyo3(get, set)]
56    pub document_index: i32,
57    #[pyo3(get, set)]
58    pub document_title: String,
59    #[pyo3(get, set)]
60    pub end_char_index: i32,
61    #[pyo3(get, set)]
62    pub start_char_index: i32,
63    #[pyo3(get)]
64    #[serde(rename = "type")]
65    pub r#type: String,
66}
67
68#[pymethods]
69impl CitationCharLocationParam {
70    #[new]
71    pub fn new(
72        cited_text: String,
73        document_index: i32,
74        document_title: String,
75        end_char_index: i32,
76        start_char_index: i32,
77    ) -> Self {
78        Self {
79            cited_text,
80            document_index,
81            document_title,
82            end_char_index,
83            start_char_index,
84            r#type: CHAR_LOCATION_TYPE.to_string(),
85        }
86    }
87}
88
89#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
90#[pyclass]
91pub struct CitationPageLocationParam {
92    #[pyo3(get, set)]
93    pub cited_text: String,
94    #[pyo3(get, set)]
95    pub document_index: i32,
96    #[pyo3(get, set)]
97    pub document_title: String,
98    #[pyo3(get, set)]
99    pub end_page_number: i32,
100    #[pyo3(get, set)]
101    pub start_page_number: i32,
102    #[pyo3(get)]
103    #[serde(rename = "type")]
104    pub r#type: String,
105}
106
107#[pymethods]
108impl CitationPageLocationParam {
109    #[new]
110    pub fn new(
111        cited_text: String,
112        document_index: i32,
113        document_title: String,
114        end_page_number: i32,
115        start_page_number: i32,
116    ) -> Self {
117        Self {
118            cited_text,
119            document_index,
120            document_title,
121            end_page_number,
122            start_page_number,
123            r#type: PAGE_LOCATION_TYPE.to_string(),
124        }
125    }
126}
127
128#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
129#[pyclass]
130pub struct CitationContentBlockLocationParam {
131    #[pyo3(get, set)]
132    pub cited_text: String,
133    #[pyo3(get, set)]
134    pub document_index: i32,
135    #[pyo3(get, set)]
136    pub document_title: String,
137    #[pyo3(get, set)]
138    pub end_block_index: i32,
139    #[pyo3(get, set)]
140    pub start_block_index: i32,
141    #[pyo3(get)]
142    #[serde(rename = "type")]
143    pub r#type: String,
144}
145
146#[pymethods]
147impl CitationContentBlockLocationParam {
148    #[new]
149    pub fn new(
150        cited_text: String,
151        document_index: i32,
152        document_title: String,
153        end_block_index: i32,
154        start_block_index: i32,
155    ) -> Self {
156        Self {
157            cited_text,
158            document_index,
159            document_title,
160            end_block_index,
161            start_block_index,
162            r#type: CONTENT_BLOCK_LOCATION_TYPE.to_string(),
163        }
164    }
165}
166
167#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
168#[pyclass]
169pub struct CitationWebSearchResultLocationParam {
170    #[pyo3(get, set)]
171    pub cited_text: String,
172    #[pyo3(get, set)]
173    pub encrypted_index: String,
174    #[pyo3(get, set)]
175    pub title: String,
176    #[pyo3(get)]
177    #[serde(rename = "type")]
178    pub r#type: String,
179    #[pyo3(get, set)]
180    pub url: String,
181}
182
183#[pymethods]
184impl CitationWebSearchResultLocationParam {
185    #[new]
186    pub fn new(cited_text: String, encrypted_index: String, title: String, url: String) -> Self {
187        Self {
188            cited_text,
189            encrypted_index,
190            title,
191            r#type: WEB_SEARCH_RESULT_LOCATION_TYPE.to_string(),
192            url,
193        }
194    }
195}
196
197#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
198#[pyclass]
199pub struct CitationSearchResultLocationParam {
200    #[pyo3(get, set)]
201    pub cited_text: String,
202    #[pyo3(get, set)]
203    pub end_block_index: i32,
204    #[pyo3(get, set)]
205    pub search_result_index: i32,
206    #[pyo3(get, set)]
207    pub source: String,
208    #[pyo3(get, set)]
209    pub start_block_index: i32,
210    #[pyo3(get, set)]
211    pub title: String,
212    #[pyo3(get)]
213    #[serde(rename = "type")]
214    pub r#type: String,
215}
216
217#[pymethods]
218impl CitationSearchResultLocationParam {
219    #[new]
220    pub fn new(
221        cited_text: String,
222        end_block_index: i32,
223        search_result_index: i32,
224        source: String,
225        start_block_index: i32,
226        title: String,
227    ) -> Self {
228        Self {
229            cited_text,
230            end_block_index,
231            search_result_index,
232            source,
233            start_block_index,
234            title,
235            r#type: SEARCH_RESULT_LOCATION_TYPE.to_string(),
236        }
237    }
238}
239
240/// Untagged enum for internal Rust usage - serializes without wrapper
241#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
242#[serde(untagged)]
243pub enum TextCitationParam {
244    CharLocation(CitationCharLocationParam),
245    PageLocation(CitationPageLocationParam),
246    ContentBlockLocation(CitationContentBlockLocationParam),
247    WebSearchResultLocation(CitationWebSearchResultLocationParam),
248    SearchResultLocation(CitationSearchResultLocationParam),
249}
250
251#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
252#[pyclass]
253pub struct TextBlockParam {
254    #[pyo3(get, set)]
255    pub text: String,
256    #[serde(skip_serializing_if = "Option::is_none")]
257    #[pyo3(get, set)]
258    pub cache_control: Option<CacheControl>,
259    #[serde(skip_serializing_if = "Option::is_none")]
260    pub citations: Option<TextCitationParam>,
261    #[pyo3(get)]
262    #[serde(rename = "type")]
263    pub r#type: String,
264}
265
266fn parse_text_citation(cit: &Bound<'_, PyAny>) -> Result<TextCitationParam, TypeError> {
267    if cit.is_instance_of::<CitationCharLocationParam>() {
268        Ok(TextCitationParam::CharLocation(
269            cit.extract::<CitationCharLocationParam>()?,
270        ))
271    } else if cit.is_instance_of::<CitationPageLocationParam>() {
272        Ok(TextCitationParam::PageLocation(
273            cit.extract::<CitationPageLocationParam>()?,
274        ))
275    } else if cit.is_instance_of::<CitationContentBlockLocationParam>() {
276        Ok(TextCitationParam::ContentBlockLocation(
277            cit.extract::<CitationContentBlockLocationParam>()?,
278        ))
279    } else if cit.is_instance_of::<CitationWebSearchResultLocationParam>() {
280        Ok(TextCitationParam::WebSearchResultLocation(
281            cit.extract::<CitationWebSearchResultLocationParam>()?,
282        ))
283    } else if cit.is_instance_of::<CitationSearchResultLocationParam>() {
284        Ok(TextCitationParam::SearchResultLocation(
285            cit.extract::<CitationSearchResultLocationParam>()?,
286        ))
287    } else {
288        Err(TypeError::InvalidInput(
289            "Invalid citation type provided".to_string(),
290        ))
291    }
292}
293#[pymethods]
294impl TextBlockParam {
295    #[new]
296    #[pyo3(signature = (text, cache_control=None, citations=None))]
297    pub fn new(
298        text: String,
299        cache_control: Option<CacheControl>,
300        citations: Option<&Bound<'_, PyAny>>,
301    ) -> Result<Self, TypeError> {
302        let citations = if let Some(cit) = citations {
303            Some(parse_text_citation(cit)?)
304        } else {
305            None
306        };
307        Ok(Self {
308            text,
309            cache_control,
310            citations,
311            r#type: TEXT_TYPE.to_string(),
312        })
313    }
314}
315
316impl TextBlockParam {
317    pub fn new_rs(
318        text: String,
319        cache_control: Option<CacheControl>,
320        citations: Option<TextCitationParam>,
321    ) -> Self {
322        Self {
323            text,
324            cache_control,
325            citations,
326            r#type: TEXT_TYPE.to_string(),
327        }
328    }
329}
330
331#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
332#[pyclass]
333pub struct Base64ImageSource {
334    #[pyo3(get, set)]
335    pub media_type: String,
336    #[pyo3(get, set)]
337    pub data: String,
338    #[pyo3(get)]
339    #[serde(rename = "type")]
340    pub r#type: String,
341}
342
343#[pymethods]
344impl Base64ImageSource {
345    #[new]
346    pub fn new(media_type: String, data: String) -> Result<Self, TypeError> {
347        // confirm media_type is an image type, otherwise raise error
348        if !get_image_media_types().contains(media_type.as_str()) {
349            return Err(TypeError::InvalidMediaType(media_type));
350        }
351        Ok(Self {
352            media_type,
353            data,
354            r#type: BASE64_TYPE.to_string(),
355        })
356    }
357}
358
359#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
360#[pyclass]
361pub struct UrlImageSource {
362    #[pyo3(get, set)]
363    pub url: String,
364    #[pyo3(get)]
365    #[serde(rename = "type")]
366    pub r#type: String,
367}
368
369#[pymethods]
370impl UrlImageSource {
371    #[new]
372    pub fn new(url: String) -> Self {
373        Self {
374            url,
375            r#type: URL_TYPE.to_string(),
376        }
377    }
378}
379
380#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
381#[serde(untagged)] // we need to strip serde type ref
382pub enum ImageSource {
383    Base64(Base64ImageSource),
384    Url(UrlImageSource),
385}
386
387#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
388#[pyclass]
389pub struct ImageBlockParam {
390    pub source: ImageSource,
391    #[serde(skip_serializing_if = "Option::is_none")]
392    #[pyo3(get, set)]
393    pub cache_control: Option<CacheControl>,
394    #[pyo3(get)]
395    #[serde(rename = "type")]
396    pub r#type: String,
397}
398
399#[pymethods]
400impl ImageBlockParam {
401    #[new]
402    #[pyo3(signature = (source, cache_control=None))]
403    pub fn new(
404        source: &Bound<'_, PyAny>,
405        cache_control: Option<CacheControl>,
406    ) -> Result<Self, TypeError> {
407        let source: ImageSource = if source.is_instance_of::<Base64ImageSource>() {
408            ImageSource::Base64(source.extract::<Base64ImageSource>()?)
409        } else {
410            ImageSource::Url(source.extract::<UrlImageSource>()?)
411        };
412        Ok(Self {
413            source,
414            cache_control,
415            r#type: IMAGE_TYPE.to_string(),
416        })
417    }
418
419    #[getter]
420    pub fn source<'py>(&self, py: Python<'py>) -> Result<Bound<'py, PyAny>, TypeError> {
421        match &self.source {
422            ImageSource::Base64(base64) => {
423                let py_obj = base64.clone().into_bound_py_any(py)?;
424                Ok(py_obj.clone())
425            }
426            ImageSource::Url(url) => {
427                let py_obj = url.clone().into_bound_py_any(py)?;
428                Ok(py_obj.clone())
429            }
430        }
431    }
432}
433
434#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
435#[pyclass]
436pub struct Base64PDFSource {
437    #[pyo3(get, set)]
438    pub media_type: String,
439    #[pyo3(get, set)]
440    pub data: String,
441    #[pyo3(get)]
442    #[serde(rename = "type")]
443    pub r#type: String,
444}
445
446#[pymethods]
447impl Base64PDFSource {
448    #[new]
449    pub fn new(data: String) -> Result<Self, TypeError> {
450        Ok(Self {
451            media_type: DOCUMENT_BASE64_PDF_TYPE.to_string(),
452            data,
453            r#type: BASE64_TYPE.to_string(),
454        })
455    }
456}
457
458#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
459#[pyclass]
460pub struct UrlPDFSource {
461    #[pyo3(get, set)]
462    pub url: String,
463    #[pyo3(get)]
464    #[serde(rename = "type")]
465    pub r#type: String,
466}
467
468#[pymethods]
469impl UrlPDFSource {
470    #[new]
471    pub fn new(url: String) -> Self {
472        Self {
473            url,
474            r#type: URL_TYPE.to_string(),
475        }
476    }
477}
478
479#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
480#[pyclass]
481pub struct PlainTextSource {
482    #[pyo3(get, set)]
483    pub media_type: String,
484    #[pyo3(get, set)]
485    pub data: String,
486    #[pyo3(get)]
487    #[serde(rename = "type")]
488    pub r#type: String,
489}
490
491#[pymethods]
492impl PlainTextSource {
493    #[new]
494    pub fn new(data: String) -> Self {
495        Self {
496            media_type: DOCUMENT_PLAIN_TEXT_TYPE.to_string(),
497            data,
498            r#type: TEXT_TYPE.to_string(),
499        }
500    }
501}
502
503#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
504#[pyclass]
505pub struct CitationsConfigParams {
506    #[pyo3(get, set)]
507    pub enabled: Option<bool>,
508}
509
510#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
511#[serde(untagged)]
512pub enum DocumentSource {
513    Base64(Base64PDFSource),
514    Url(UrlPDFSource),
515    Text(PlainTextSource),
516}
517
518#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
519#[pyclass]
520pub struct DocumentBlockParam {
521    pub source: DocumentSource,
522    #[serde(skip_serializing_if = "Option::is_none")]
523    #[pyo3(get, set)]
524    pub cache_control: Option<CacheControl>,
525    #[serde(skip_serializing_if = "Option::is_none")]
526    #[pyo3(get, set)]
527    pub title: Option<String>,
528    #[serde(skip_serializing_if = "Option::is_none")]
529    #[pyo3(get, set)]
530    pub context: Option<String>,
531    #[serde(rename = "type")]
532    #[pyo3(get, set)]
533    pub r#type: String,
534    #[serde(skip_serializing_if = "Option::is_none")]
535    #[pyo3(get, set)]
536    pub citations: Option<CitationsConfigParams>,
537}
538
539#[pymethods]
540impl DocumentBlockParam {
541    #[new]
542    #[pyo3(signature = (source, cache_control=None, title=None, context=None, citations=None))]
543    pub fn new(
544        source: &Bound<'_, PyAny>,
545        cache_control: Option<CacheControl>,
546        title: Option<String>,
547        context: Option<String>,
548        citations: Option<CitationsConfigParams>,
549    ) -> Result<Self, TypeError> {
550        let source: DocumentSource = if source.is_instance_of::<Base64PDFSource>() {
551            DocumentSource::Base64(source.extract::<Base64PDFSource>()?)
552        } else if source.is_instance_of::<UrlPDFSource>() {
553            DocumentSource::Url(source.extract::<UrlPDFSource>()?)
554        } else {
555            DocumentSource::Text(source.extract::<PlainTextSource>()?)
556        };
557
558        Ok(Self {
559            source,
560            cache_control,
561            title,
562            context,
563            r#type: DOCUMENT_TYPE.to_string(),
564            citations,
565        })
566    }
567}
568
569#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
570#[pyclass]
571pub struct SearchResultBlockParam {
572    #[pyo3(get, set)]
573    pub content: Vec<TextBlockParam>,
574    #[pyo3(get, set)]
575    pub source: String,
576    #[pyo3(get, set)]
577    pub title: String,
578    #[serde(rename = "type")]
579    #[pyo3(get, set)]
580    pub r#type: String,
581    #[pyo3(get, set)]
582    pub cache_control: Option<CacheControl>,
583    #[pyo3(get, set)]
584    pub citations: Option<CitationsConfigParams>,
585}
586
587#[pymethods]
588impl SearchResultBlockParam {
589    #[new]
590    #[pyo3(signature = (content, source, title, cache_control=None, citations=None))]
591    pub fn new(
592        content: Vec<TextBlockParam>,
593        source: String,
594        title: String,
595        cache_control: Option<CacheControl>,
596        citations: Option<CitationsConfigParams>,
597    ) -> Self {
598        Self {
599            content,
600            source,
601            title,
602            r#type: SEARCH_TYPE.to_string(),
603            cache_control,
604            citations,
605        }
606    }
607}
608
609#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
610#[pyclass]
611pub struct ThinkingBlockParam {
612    #[pyo3(get, set)]
613    pub thinking: String,
614    #[pyo3(get, set)]
615    pub signature: Option<String>,
616    #[pyo3(get, set)]
617    #[serde(rename = "type")]
618    pub r#type: String,
619}
620
621#[pymethods]
622impl ThinkingBlockParam {
623    #[new]
624    #[pyo3(signature = (thinking, signature=None))]
625    pub fn new(thinking: String, signature: Option<String>) -> Self {
626        Self {
627            thinking,
628            signature,
629            r#type: THINKING_TYPE.to_string(),
630        }
631    }
632}
633
634#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
635#[pyclass]
636pub struct RedactedThinkingBlockParam {
637    #[pyo3(get, set)]
638    pub data: String,
639    #[pyo3(get, set)]
640    #[serde(rename = "type")]
641    pub r#type: String,
642}
643
644#[pymethods]
645impl RedactedThinkingBlockParam {
646    #[new]
647    pub fn new(data: String) -> Self {
648        Self {
649            data,
650            r#type: REDACTED_THINKING_TYPE.to_string(),
651        }
652    }
653}
654
655#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
656#[pyclass]
657pub struct ToolUseBlockParam {
658    #[pyo3(get, set)]
659    pub id: String,
660    #[pyo3(get, set)]
661    pub name: String,
662    pub input: Value,
663    #[serde(skip_serializing_if = "Option::is_none")]
664    #[pyo3(get, set)]
665    pub cache_control: Option<CacheControl>,
666    #[pyo3(get, set)]
667    #[serde(rename = "type")]
668    pub r#type: String,
669}
670
671#[pymethods]
672impl ToolUseBlockParam {
673    #[new]
674    #[pyo3(signature = (id, name, input, cache_control=None))]
675    pub fn new(
676        id: String,
677        name: String,
678        input: &Bound<'_, PyAny>,
679        cache_control: Option<CacheControl>,
680    ) -> Result<Self, TypeError> {
681        let input_value = depythonize(input)?;
682        Ok(Self {
683            id,
684            name,
685            input: input_value,
686            cache_control,
687            r#type: TOOL_USE_TYPE.to_string(),
688        })
689    }
690
691    #[getter]
692    pub fn input<'py>(&self, py: Python<'py>) -> Result<Bound<'py, PyAny>, TypeError> {
693        let py_dict = pythonize(py, &self.input)?;
694        Ok(py_dict)
695    }
696}
697
698#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
699#[serde(untagged)]
700pub enum ToolResultContentEnum {
701    Text(Vec<TextBlockParam>),
702    Image(Vec<ImageBlockParam>),
703    Document(Vec<DocumentBlockParam>),
704    SearchResult(Vec<SearchResultBlockParam>),
705}
706
707#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
708#[pyclass]
709pub struct ToolResultBlockParam {
710    #[pyo3(get, set)]
711    pub tool_use_id: String,
712    #[serde(skip_serializing_if = "Option::is_none")]
713    pub is_error: Option<bool>,
714    #[serde(skip_serializing_if = "Option::is_none")]
715    #[pyo3(get, set)]
716    pub cache_control: Option<CacheControl>,
717    #[pyo3(get, set)]
718    #[serde(rename = "type")]
719    pub r#type: String,
720    #[serde(skip_serializing_if = "Option::is_none")]
721    pub content: Option<ToolResultContentEnum>,
722}
723
724/// Helper function to extract all blocks of a specific type
725fn extract_all_blocks<T>(blocks: Vec<Bound<'_, PyAny>>) -> Result<Vec<T>, TypeError>
726where
727    T: for<'a, 'py> FromPyObject<'a, 'py>,
728{
729    blocks
730        .into_iter()
731        .map(|block| {
732            block
733                .extract::<T>()
734                .map_err(|_| TypeError::Error("Failed to extract block".to_string()))
735        })
736        .collect()
737}
738
739#[pymethods]
740impl ToolResultBlockParam {
741    #[new]
742    #[pyo3(signature = (
743        tool_use_id,
744        is_error=None,
745        cache_control=None,
746        content=None
747    ))]
748    pub fn new(
749        tool_use_id: String,
750        is_error: Option<bool>,
751        cache_control: Option<CacheControl>,
752        content: Option<Vec<Bound<'_, PyAny>>>,
753    ) -> Result<Self, TypeError> {
754        let content_enum = match content {
755            None => None,
756            Some(blocks) if blocks.is_empty() => None,
757
758            Some(blocks) => {
759                let first_block = &blocks[0];
760
761                if first_block.is_instance_of::<TextBlockParam>() {
762                    Some(ToolResultContentEnum::Text(extract_all_blocks(blocks)?))
763                } else if first_block.is_instance_of::<ImageBlockParam>() {
764                    Some(ToolResultContentEnum::Image(extract_all_blocks(blocks)?))
765                } else if first_block.is_instance_of::<DocumentBlockParam>() {
766                    Some(ToolResultContentEnum::Document(extract_all_blocks(blocks)?))
767                } else if first_block.is_instance_of::<SearchResultBlockParam>() {
768                    Some(ToolResultContentEnum::SearchResult(extract_all_blocks(
769                        blocks,
770                    )?))
771                } else {
772                    return Err(TypeError::InvalidInput(
773                        "Unsupported content block type".to_string(),
774                    ));
775                }
776            }
777        };
778
779        Ok(Self {
780            tool_use_id,
781            is_error,
782            cache_control,
783            r#type: TOOL_RESULT_TYPE.to_string(),
784            content: content_enum,
785        })
786    }
787
788    #[getter]
789    pub fn content<'py>(&self, py: Python<'py>) -> Result<Option<Bound<'py, PyAny>>, TypeError> {
790        match &self.content {
791            None => Ok(None),
792            Some(ToolResultContentEnum::Text(blocks)) => {
793                let py_list = blocks
794                    .iter()
795                    .map(|block| block.clone().into_bound_py_any(py))
796                    .collect::<Result<Vec<_>, _>>()?;
797                Ok(Some(py_list.into_bound_py_any(py)?))
798            }
799            Some(ToolResultContentEnum::Image(blocks)) => {
800                let py_list = blocks
801                    .iter()
802                    .map(|block| block.clone().into_bound_py_any(py))
803                    .collect::<Result<Vec<_>, _>>()?;
804                Ok(Some(py_list.into_bound_py_any(py)?))
805            }
806            Some(ToolResultContentEnum::Document(blocks)) => {
807                let py_list = blocks
808                    .iter()
809                    .map(|block| block.clone().into_bound_py_any(py))
810                    .collect::<Result<Vec<_>, _>>()?;
811                Ok(Some(py_list.into_bound_py_any(py)?))
812            }
813            Some(ToolResultContentEnum::SearchResult(blocks)) => {
814                let py_list = blocks
815                    .iter()
816                    .map(|block| block.clone().into_bound_py_any(py))
817                    .collect::<Result<Vec<_>, _>>()?;
818                Ok(Some(py_list.into_bound_py_any(py)?))
819            }
820        }
821    }
822}
823
824#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
825#[pyclass]
826pub struct ServerToolUseBlockParam {
827    #[pyo3(get, set)]
828    pub id: String,
829    #[pyo3(get, set)]
830    pub name: String,
831    pub input: Value,
832    #[serde(skip_serializing_if = "Option::is_none")]
833    pub cache_control: Option<CacheControl>,
834    #[pyo3(get, set)]
835    #[serde(rename = "type")]
836    pub r#type: String,
837}
838
839#[pymethods]
840impl ServerToolUseBlockParam {
841    #[new]
842    #[pyo3(signature = (id, name, input, cache_control=None))]
843    pub fn new(
844        id: String,
845        name: String,
846        input: &Bound<'_, PyAny>,
847        cache_control: Option<CacheControl>,
848    ) -> Result<Self, TypeError> {
849        let input_value = depythonize(input)?;
850        Ok(Self {
851            id,
852            name,
853            input: input_value,
854            cache_control,
855            r#type: SERVER_TOOL_USE_TYPE.to_string(),
856        })
857    }
858    #[getter]
859    pub fn input<'py>(&self, py: Python<'py>) -> Result<Bound<'py, PyAny>, TypeError> {
860        let py_dict = pythonize(py, &self.input)?;
861        Ok(py_dict)
862    }
863}
864
865#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
866#[pyclass]
867pub struct WebSearchResultBlockParam {
868    #[pyo3(get, set)]
869    pub encrypted_content: String,
870    #[pyo3(get, set)]
871    pub title: String,
872    #[pyo3(get, set)]
873    pub url: String,
874    #[pyo3(get, set)]
875    pub page_agent: Option<String>,
876    #[pyo3(get, set)]
877    #[serde(rename = "type")]
878    pub r#type: String,
879}
880
881#[pymethods]
882impl WebSearchResultBlockParam {
883    #[new]
884    #[pyo3(signature = (encrypted_content, title, url, page_agent=None))]
885    pub fn new(
886        encrypted_content: String,
887        title: String,
888        url: String,
889        page_agent: Option<String>,
890    ) -> Self {
891        Self {
892            encrypted_content,
893            title,
894            url,
895            page_agent,
896            r#type: WEB_SEARCH_RESULT_TYPE.to_string(),
897        }
898    }
899}
900
901#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
902#[pyclass]
903pub struct WebSearchToolResultBlockParam {
904    #[pyo3(get, set)]
905    pub tool_use_id: String,
906    #[pyo3(get, set)]
907    pub content: Vec<WebSearchResultBlockParam>,
908    #[serde(skip_serializing_if = "Option::is_none")]
909    #[pyo3(get, set)]
910    pub cache_control: Option<CacheControl>,
911    #[pyo3(get, set)]
912    #[serde(rename = "type")]
913    pub r#type: String,
914}
915
916#[pymethods]
917impl WebSearchToolResultBlockParam {
918    #[new]
919    #[pyo3(signature = (tool_use_id, content, cache_control=None))]
920    pub fn new(
921        tool_use_id: String,
922        content: Vec<WebSearchResultBlockParam>,
923        cache_control: Option<CacheControl>,
924    ) -> Self {
925        Self {
926            tool_use_id,
927            content,
928            cache_control,
929            r#type: WEB_SEARCH_TOOL_RESULT_TYPE.to_string(),
930        }
931    }
932}
933
934#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
935#[serde(untagged)]
936pub(crate) enum ContentBlock {
937    Text(TextBlockParam),
938    Image(ImageBlockParam),
939    Document(DocumentBlockParam),
940    SearchResult(SearchResultBlockParam),
941    Thinking(ThinkingBlockParam),
942    RedactedThinking(RedactedThinkingBlockParam),
943    ToolUse(ToolUseBlockParam),
944    ToolResult(ToolResultBlockParam),
945    ServerToolUse(ServerToolUseBlockParam),
946    WebSearchResult(WebSearchResultBlockParam),
947}
948
// Expands to a chain of instance checks over `Variant => Type` pairs.
//
// For the first pair whose `Type` the object is an instance of, the enclosing function
// RETURNS `Ok(Self { inner: ContentBlock::Variant(..) })`. If no pair matches, the
// enclosing function returns `TypeError::InvalidInput`. Because every path ends in
// `return`, the macro only makes sense inside a function returning
// `Result<Self, TypeError>` where `Self` has an `inner: ContentBlock` field.
macro_rules! try_extract_content_block {
    ($obj:expr, $($arm:ident => $ty:ty),+ $(,)?) => {{
        $(
            if $obj.is_instance_of::<$ty>() {
                let extracted = $obj.extract::<$ty>()?;
                return Ok(Self {
                    inner: ContentBlock::$arm(extracted),
                });
            }
        )+
        return Err(TypeError::InvalidInput(
            "Unsupported content block type".to_string(),
        ));
    }};
}
963
964#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
965pub struct ContentBlockParam {
966    #[serde(flatten)]
967    pub(crate) inner: ContentBlock,
968}
969
970impl ContentBlockParam {
971    pub fn new(block: &Bound<'_, PyAny>) -> Result<Self, TypeError> {
972        try_extract_content_block!(
973            block,
974            Text => TextBlockParam,
975            Image => ImageBlockParam,
976            Document => DocumentBlockParam,
977            SearchResult => SearchResultBlockParam,
978            Thinking => ThinkingBlockParam,
979            RedactedThinking => RedactedThinkingBlockParam,
980            ToolUse => ToolUseBlockParam,
981            ToolResult => ToolResultBlockParam,
982            ServerToolUse => ServerToolUseBlockParam,
983            WebSearchResult => WebSearchResultBlockParam,
984        )
985    }
986
987    /// Convert the ContentBlockParam back to a PyObject
988    /// This is an acceptable clone, as this will really only be used in development/testing scenarios
989    pub fn to_pyobject<'py>(&self, py: Python<'py>) -> Result<Bound<'py, PyAny>, TypeError> {
990        match &self.inner {
991            ContentBlock::Text(block) => Ok(block.clone().into_bound_py_any(py)?.clone()),
992            ContentBlock::Image(block) => Ok(block.clone().into_bound_py_any(py)?.clone()),
993            ContentBlock::Document(block) => Ok(block.clone().into_bound_py_any(py)?.clone()),
994            ContentBlock::SearchResult(block) => Ok(block.clone().into_bound_py_any(py)?.clone()),
995            ContentBlock::Thinking(block) => Ok(block.clone().into_bound_py_any(py)?.clone()),
996            ContentBlock::RedactedThinking(block) => {
997                Ok(block.clone().into_bound_py_any(py)?.clone())
998            }
999            ContentBlock::ToolUse(block) => Ok(block.clone().into_bound_py_any(py)?.clone()),
1000            ContentBlock::ToolResult(block) => Ok(block.clone().into_bound_py_any(py)?.clone()),
1001            ContentBlock::ServerToolUse(block) => Ok(block.clone().into_bound_py_any(py)?.clone()),
1002            ContentBlock::WebSearchResult(block) => {
1003                Ok(block.clone().into_bound_py_any(py)?.clone())
1004            }
1005        }
1006    }
1007}
1008
1009impl ContentBlockParam {
1010    /// Helper function to create a ContentBlockParam from a ResponseContentBlockInner
1011    ///
1012    /// Converts response content blocks (from API responses) into request content blocks
1013    /// that can be used in subsequent requests. This is useful for multi-turn conversations
1014    /// where the assistant's response needs to be included in the next request.
1015    ///
1016    /// # Arguments
1017    /// * `block` - A reference to a ResponseContentBlockInner from an API response
1018    ///
1019    /// # Returns
1020    /// * `Result<Self, TypeError>` - The converted ContentBlockParam or an error
1021    pub(crate) fn from_response_content_block(
1022        block: &ResponseContentBlockInner,
1023    ) -> Result<Self, TypeError> {
1024        match block {
1025            ResponseContentBlockInner::Text(text_block) => {
1026                // Convert Vec<TextCitation> to Option<TextCitationParam>
1027                // Take the first citation if available, as request format uses singular citation
1028                let citations = text_block.citations.as_ref().and_then(|cits| {
1029                    cits.first().map(|cit| match cit {
1030                        TextCitation::CharLocation(c) => {
1031                            TextCitationParam::CharLocation(CitationCharLocationParam {
1032                                cited_text: c.cited_text.clone(),
1033                                document_index: c.document_index,
1034                                document_title: c.document_title.clone(),
1035                                end_char_index: c.end_char_index,
1036                                start_char_index: c.start_char_index,
1037                                r#type: c.r#type.clone(),
1038                            })
1039                        }
1040                        TextCitation::PageLocation(c) => {
1041                            TextCitationParam::PageLocation(CitationPageLocationParam {
1042                                cited_text: c.cited_text.clone(),
1043                                document_index: c.document_index,
1044                                document_title: c.document_title.clone(),
1045                                end_page_number: c.end_page_number,
1046                                start_page_number: c.start_page_number,
1047                                r#type: c.r#type.clone(),
1048                            })
1049                        }
1050                        TextCitation::ContentBlockLocation(c) => {
1051                            TextCitationParam::ContentBlockLocation(
1052                                CitationContentBlockLocationParam {
1053                                    cited_text: c.cited_text.clone(),
1054                                    document_index: c.document_index,
1055                                    document_title: c.document_title.clone(),
1056                                    end_block_index: c.end_block_index,
1057                                    start_block_index: c.start_block_index,
1058                                    r#type: c.r#type.clone(),
1059                                },
1060                            )
1061                        }
1062                        TextCitation::WebSearchResultLocation(c) => {
1063                            TextCitationParam::WebSearchResultLocation(
1064                                CitationWebSearchResultLocationParam {
1065                                    cited_text: c.cited_text.clone(),
1066                                    encrypted_index: c.encrypted_index.clone(),
1067                                    title: c.title.clone(),
1068                                    r#type: c.r#type.clone(),
1069                                    url: c.url.clone(),
1070                                },
1071                            )
1072                        }
1073                        TextCitation::SearchResultLocation(c) => {
1074                            TextCitationParam::SearchResultLocation(
1075                                CitationSearchResultLocationParam {
1076                                    cited_text: c.cited_text.clone(),
1077                                    end_block_index: c.end_block_index,
1078                                    search_result_index: c.search_result_index,
1079                                    source: c.source.clone(),
1080                                    start_block_index: c.start_block_index,
1081                                    title: c.title.clone(),
1082                                    r#type: c.r#type.clone(),
1083                                },
1084                            )
1085                        }
1086                    })
1087                });
1088
1089                Ok(Self {
1090                    inner: ContentBlock::Text(TextBlockParam {
1091                        text: text_block.text.clone(),
1092                        cache_control: None,
1093                        citations,
1094                        r#type: text_block.r#type.clone(),
1095                    }),
1096                })
1097            }
1098            ResponseContentBlockInner::Thinking(thinking_block) => Ok(Self {
1099                inner: ContentBlock::Thinking(ThinkingBlockParam {
1100                    thinking: thinking_block.thinking.clone(),
1101                    signature: thinking_block.signature.clone(),
1102                    r#type: thinking_block.r#type.clone(),
1103                }),
1104            }),
1105            ResponseContentBlockInner::RedactedThinking(redacted_thinking_block) => Ok(Self {
1106                inner: ContentBlock::RedactedThinking(RedactedThinkingBlockParam {
1107                    data: redacted_thinking_block.data.clone(),
1108                    r#type: redacted_thinking_block.r#type.clone(),
1109                }),
1110            }),
1111            ResponseContentBlockInner::ToolUse(tool_use_block) => Ok(Self {
1112                inner: ContentBlock::ToolUse(ToolUseBlockParam {
1113                    id: tool_use_block.id.clone(),
1114                    name: tool_use_block.name.clone(),
1115                    input: tool_use_block.input.clone(),
1116                    cache_control: None,
1117                    r#type: tool_use_block.r#type.clone(),
1118                }),
1119            }),
1120            ResponseContentBlockInner::ServerToolUse(server_tool_use_block) => Ok(Self {
1121                inner: ContentBlock::ServerToolUse(ServerToolUseBlockParam {
1122                    id: server_tool_use_block.id.clone(),
1123                    name: server_tool_use_block.name.clone(),
1124                    input: server_tool_use_block.input.clone(),
1125                    cache_control: None,
1126                    r#type: server_tool_use_block.r#type.clone(),
1127                }),
1128            }),
1129            ResponseContentBlockInner::WebSearchToolResult(web_search_tool_result_block) => {
1130                match &web_search_tool_result_block.content {
1131                    WebSearchToolResultBlockContent::Results(results) => {
1132                        // Take the first result and convert it to WebSearchResultBlockParam
1133                        let first_result = results.first().ok_or_else(|| {
1134                            TypeError::InvalidInput(
1135                                "WebSearchToolResult must contain at least one result".to_string(),
1136                            )
1137                        })?;
1138
1139                        Ok(Self {
1140                            inner: ContentBlock::WebSearchResult(WebSearchResultBlockParam {
1141                                encrypted_content: first_result.encrypted_content.clone(),
1142                                title: first_result.title.clone(),
1143                                url: first_result.url.clone(),
1144                                page_agent: first_result.page_age.clone(), // Note: page_age -> page_agent mapping
1145                                r#type: first_result.r#type.clone(),
1146                            }),
1147                        })
1148                    }
1149                    WebSearchToolResultBlockContent::Error(_) => Err(TypeError::InvalidInput(
1150                        "Cannot convert WebSearchToolResult error to ContentBlockParam".to_string(),
1151                    )),
1152                }
1153            }
1154        }
1155    }
1156}
1157
/// A single request message: a role plus an ordered list of content blocks.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
#[pyclass]
pub struct MessageParam {
    /// Content blocks of the message. No generated pyo3 getter is attached
    /// here; Python reads the blocks through the hand-written `content` getter.
    pub content: Vec<ContentBlockParam>,
    /// Role of the message author, readable from Python.
    #[pyo3(get)]
    pub role: String,
}
1165
1166#[pymethods]
1167impl MessageParam {
1168    /// Create a new MessageParam
1169    /// Will accept either a string, Single type of ContentBlockParam, or a list of varying ContentBlockParams
1170    /// Initial logic determines initial type and processes accordingly
1171    /// If string, wraps in TextBlockParam --> ContentBlockParam
1172    /// If single param, pass to ContentBlockParam directly
1173    /// If list, process each item in list to ContentBlockParam
1174    #[new]
1175    pub fn new(content: &Bound<'_, PyAny>, role: String) -> Result<Self, TypeError> {
1176        let content: Vec<ContentBlockParam> = if content.is_instance_of::<pyo3::types::PyString>() {
1177            let text = content.extract::<String>()?;
1178            let text_block = TextBlockParam::new(text, None, None)?;
1179            let content_block =
1180                ContentBlockParam::new(&text_block.into_bound_py_any(content.py())?)?;
1181            vec![content_block]
1182        } else if content.is_instance_of::<pyo3::types::PyList>() {
1183            let content_list = content.extract::<Vec<Bound<'_, PyAny>>>()?;
1184            let mut blocks = Vec::new();
1185            for item in content_list {
1186                let content_block = ContentBlockParam::new(&item)?;
1187                blocks.push(content_block);
1188            }
1189            blocks
1190        } else {
1191            // pass single ContentBlockParam
1192            let content_block = ContentBlockParam::new(content)?;
1193            vec![content_block]
1194        };
1195
1196        Ok(Self { content, role })
1197    }
1198
1199    #[getter]
1200    fn content<'py>(&self, py: Python<'py>) -> Result<Vec<Bound<'py, PyAny>>, TypeError> {
1201        self.content
1202            .iter()
1203            .map(|block| block.to_pyobject(py))
1204            .collect()
1205    }
1206
1207    #[pyo3(name = "bind")]
1208    fn bind_py(&self, name: &str, value: &str) -> Result<Self, TypeError> {
1209        self.bind(name, value)
1210    }
1211
1212    #[pyo3(name = "bind_mut")]
1213    fn bind_mut_py(&mut self, name: &str, value: &str) -> Result<(), TypeError> {
1214        self.bind_mut(name, value)
1215    }
1216
1217    pub fn __str__(&self) -> String {
1218        PyHelperFuncs::__str__(self)
1219    }
1220
1221    pub fn model_dump<'py>(&self, py: Python<'py>) -> Result<Bound<'py, PyAny>, TypeError> {
1222        // iterate over each field in model_settings and add to the dict if it is not None
1223        let json = serde_json::to_value(self)?;
1224        Ok(pythonize(py, &json)?)
1225    }
1226
1227    // Return the text content from the first content part that is text
1228    #[getter]
1229    pub fn text(&self) -> String {
1230        self.content
1231            .iter()
1232            .find_map(|part| {
1233                if let ContentBlock::Text(text_part) = &part.inner {
1234                    Some(text_part.text.clone())
1235                } else {
1236                    None
1237                }
1238            })
1239            .unwrap_or_default()
1240    }
1241}
1242
1243impl PromptMessageExt for MessageParam {
1244    fn bind_mut(&mut self, name: &str, value: &str) -> Result<(), TypeError> {
1245        let regex = get_var_regex();
1246
1247        for part in &mut self.content {
1248            if let ContentBlock::Text(text_part) = &mut part.inner {
1249                text_part.text = regex
1250                    .replace_all(&text_part.text, |caps: &regex::Captures| {
1251                        if &caps[1] == name {
1252                            value.to_string()
1253                        } else {
1254                            caps.get(0)
1255                                .map_or_else(|| "".to_string(), |m| m.as_str().to_string())
1256                        }
1257                    })
1258                    .to_string();
1259            }
1260        }
1261
1262        Ok(())
1263    }
1264
1265    fn bind(&self, name: &str, value: &str) -> Result<Self, TypeError>
1266    where
1267        Self: Sized,
1268    {
1269        let mut new_message = self.clone();
1270        new_message.bind_mut(name, value)?;
1271        Ok(new_message)
1272    }
1273
1274    fn extract_variables(&self) -> Vec<String> {
1275        let mut variables = HashSet::new();
1276
1277        // Lazily initialize regex to avoid recompilation
1278        let regex = get_var_regex();
1279
1280        // Extract variables from all text content parts
1281        for part in &self.content {
1282            if let ContentBlock::Text(text_part) = &part.inner {
1283                for captures in regex.captures_iter(&text_part.text) {
1284                    if let Some(name) = captures.get(1) {
1285                        variables.insert(name.as_str().to_string());
1286                    }
1287                }
1288            }
1289        }
1290
1291        // Convert HashSet to Vec for return
1292        variables.into_iter().collect()
1293    }
1294
1295    fn from_text(content: String, role: &str) -> Result<Self, TypeError> {
1296        Ok(Self {
1297            role: role.to_string(),
1298            content: vec![ContentBlockParam {
1299                inner: ContentBlock::Text(TextBlockParam::new_rs(content, None, None)),
1300            }],
1301        })
1302    }
1303}
1304
1305impl MessageParam {
1306    /// Helper function to create a MessageParam from a single TextBlockParam
1307    pub fn to_text_block_param(&self) -> Result<TextBlockParam, TypeError> {
1308        if self.content.len() != 1 {
1309            return Err(TypeError::InvalidInput(
1310                "MessageParam must contain exactly one content block to convert to TextBlockParam"
1311                    .to_string(),
1312            ));
1313        }
1314
1315        match &self.content[0].inner {
1316            ContentBlock::Text(text_block) => Ok(text_block.clone()),
1317            _ => Err(TypeError::InvalidInput(
1318                "Content block is not of type TextBlockParam".to_string(),
1319            )),
1320        }
1321    }
1322}
1323
1324impl MessageFactory for MessageParam {
1325    fn from_text(content: String, role: &str) -> Result<Self, TypeError> {
1326        let text_block = TextBlockParam::new_rs(content, None, None);
1327        let content_block = ContentBlockParam {
1328            inner: ContentBlock::Text(text_block),
1329        };
1330
1331        Ok(Self {
1332            role: role.to_string(),
1333            content: vec![content_block],
1334        })
1335    }
1336}
1337
1338impl MessageConversion for MessageParam {
1339    fn to_anthropic_message(&self) -> Result<Self, TypeError> {
1340        // Currently, MessageParam is already in the Anthropic format
1341        Err(TypeError::CantConvertSelf)
1342    }
1343
1344    fn to_google_message(
1345        &self,
1346    ) -> Result<crate::google::v1::generate::request::GeminiContent, TypeError> {
1347        // Extract text content from all text blocks
1348        let mut parts = Vec::new();
1349
1350        for content_block in &self.content {
1351            match &content_block.inner {
1352                ContentBlock::Text(text_block) => {
1353                    parts.push(Part {
1354                        data: DataNum::Text(text_block.text.clone()),
1355                        thought: None,
1356                        thought_signature: None,
1357                        part_metadata: None,
1358                        media_resolution: None,
1359                        video_metadata: None,
1360                    });
1361                }
1362                _ => {
1363                    return Err(TypeError::UnsupportedConversion(
1364                        "Only text content blocks are currently supported for conversion"
1365                            .to_string(),
1366                    ));
1367                }
1368            }
1369        }
1370
1371        if parts.is_empty() {
1372            return Err(TypeError::UnsupportedConversion(
1373                "Message contains no text content to convert".to_string(),
1374            ));
1375        }
1376
1377        Ok(GeminiContent {
1378            role: self.role.clone(),
1379            parts,
1380        })
1381    }
1382
1383    fn to_openai_message(
1384        &self,
1385    ) -> Result<crate::openai::v1::chat::request::ChatMessage, TypeError> {
1386        // Extract text content from all text blocks
1387        let mut content_parts = Vec::new();
1388
1389        for content_block in &self.content {
1390            match &content_block.inner {
1391                ContentBlock::Text(text_block) => {
1392                    content_parts.push(ContentPart::Text(TextContentPart::new(
1393                        text_block.text.clone(),
1394                    )));
1395                }
1396                _ => {
1397                    return Err(TypeError::UnsupportedConversion(
1398                        "Only text content blocks are currently supported for conversion"
1399                            .to_string(),
1400                    ));
1401                }
1402            }
1403        }
1404
1405        if content_parts.is_empty() {
1406            return Err(TypeError::UnsupportedConversion(
1407                "Message contains no text content to convert".to_string(),
1408            ));
1409        }
1410
1411        Ok(ChatMessage {
1412            role: self.role.clone(),
1413            content: content_parts,
1414            name: None,
1415        })
1416    }
1417}
1418
/// Request metadata; currently carries only an optional user id.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
#[pyclass]
pub struct Metadata {
    /// Optional user identifier; left out of the serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub user_id: Option<String>,
}

#[pymethods]
impl Metadata {
    /// Create metadata; `user_id` defaults to `None` when omitted from Python.
    #[new]
    #[pyo3(signature = (user_id=None))]
    pub fn new(user_id: Option<String>) -> Self {
        Self { user_id }
    }
}
1434
/// Cache-control marker that can be attached to blocks and tools.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
#[pyclass]
pub struct CacheControl {
    /// Cache type; serialized under the JSON key `type`.
    #[serde(rename = "type")]
    pub cache_type: String, // "ephemeral"
    /// Optional time-to-live; left out of the serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub ttl: Option<String>, // "5m" or "1h"
}

#[pymethods]
impl CacheControl {
    /// Create a cache-control value; `ttl` defaults to `None` when omitted from Python.
    /// No validation is performed on either argument.
    #[new]
    #[pyo3(signature = (cache_type, ttl=None))]
    pub fn new(cache_type: String, ttl: Option<String>) -> Self {
        Self { cache_type, ttl }
    }
}
1452
/// Tool definition sent with a request; exposed to Python as `AnthropicTool`.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
#[pyclass(name = "AnthropicTool")]
pub struct Tool {
    /// Tool name.
    pub name: String,
    /// Optional description; left out of the serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub description: Option<String>,
    /// Schema of the tool input, stored as an arbitrary JSON value.
    pub input_schema: Value,
    /// Optional cache-control marker; left out of the serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cache_control: Option<CacheControl>,
}
1463
1464#[pymethods]
1465impl Tool {
1466    #[new]
1467    #[pyo3(signature = (
1468        name,
1469        input_schema,
1470        description=None,
1471        cache_control=None
1472    ))]
1473    pub fn new(
1474        name: String,
1475        input_schema: &Bound<'_, PyAny>,
1476        description: Option<String>,
1477        cache_control: Option<CacheControl>,
1478    ) -> Result<Self, UtilError> {
1479        Ok(Self {
1480            name,
1481            description,
1482            input_schema: depythonize(input_schema)?,
1483            cache_control,
1484        })
1485    }
1486}
1487
1488impl Tool {
1489    pub fn from_tool_agent_tool_definition(tool: &AgentToolDefinition) -> Result<Self, TypeError> {
1490        Ok(Self {
1491            name: tool.name.clone(),
1492            description: Some(tool.description.clone()),
1493            input_schema: tool.parameters.clone(),
1494            cache_control: None,
1495        })
1496    }
1497}
1498
/// Thinking configuration; exposed to Python as `AnthropicThinkingConfig`.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
#[pyclass(name = "AnthropicThinkingConfig")]
pub struct ThinkingConfig {
    /// Configuration type string, readable from Python.
    #[pyo3(get)]
    pub r#type: String,

    /// Optional token budget; left out of the serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[pyo3(get)]
    pub budget_tokens: Option<i32>,
}

#[pymethods]
impl ThinkingConfig {
    /// Create a thinking configuration; `budget_tokens` defaults to `None`.
    /// No validation is performed on either argument.
    #[new]
    #[pyo3(signature = (r#type, budget_tokens=None))]
    pub fn new(r#type: String, budget_tokens: Option<i32>) -> Self {
        Self {
            r#type,
            budget_tokens,
        }
    }
}
1521
/// Tool-choice setting; exposed to Python as `AnthropicToolChoice`.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
#[pyclass(name = "AnthropicToolChoice")]
pub struct ToolChoice {
    /// Kind of tool choice, readable from Python.
    #[pyo3(get)]
    pub r#type: String, // "auto", "any", "tool", "none"

    /// Optional flag; left out of the serialized output when `None`.
    /// Unlike the other fields this one is private on the Rust side.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[pyo3(get)]
    disable_parallel_tool_use: Option<bool>,

    /// Tool name; the constructor only accepts it together with type "tool".
    #[serde(skip_serializing_if = "Option::is_none")]
    #[pyo3(get)]
    pub name: Option<String>,
}
1536
1537#[pymethods]
1538impl ToolChoice {
1539    #[new]
1540    #[pyo3(signature = (r#type, disable_parallel_tool_use=None, name=None))]
1541    pub fn new(
1542        r#type: String,
1543        disable_parallel_tool_use: Option<bool>,
1544        name: Option<String>,
1545    ) -> Result<Self, UtilError> {
1546        match name {
1547            Some(_) if r#type != "tool" => {
1548                return Err(UtilError::PyError(
1549                    "ToolChoice name can only be set if type is 'tool'".to_string(),
1550                ))
1551            }
1552            None if r#type == "tool" => {
1553                return Err(UtilError::PyError(
1554                    "ToolChoice of type 'tool' requires a name".to_string(),
1555                ))
1556            }
1557            _ => {}
1558        }
1559
1560        Ok(Self {
1561            r#type,
1562            disable_parallel_tool_use,
1563            name,
1564        })
1565    }
1566}
1567
/// Model settings for Anthropic requests.
///
/// `#[serde(default)]` makes fields missing during deserialization fall back to
/// the values of the `Default` implementation. Every `Option` field is left out
/// of the serialized output when it is `None`.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
#[pyclass]
#[serde(default)]
pub struct AnthropicSettings {
    /// Maximum number of tokens; the only non-optional setting.
    #[pyo3(get)]
    pub max_tokens: i32,

    /// Optional request metadata.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[pyo3(get)]
    pub metadata: Option<Metadata>,

    /// Optional service tier string.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[pyo3(get)]
    pub service_tier: Option<String>,

    /// Optional list of stop sequences.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[pyo3(get)]
    pub stop_sequences: Option<Vec<String>>,

    /// Optional streaming flag.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[pyo3(get)]
    pub stream: Option<bool>,

    /// Optional system prompt given as a plain string.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[pyo3(get)]
    pub system: Option<String>,

    /// Optional sampling temperature.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[pyo3(get)]
    pub temperature: Option<f32>,

    /// Optional thinking configuration.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[pyo3(get)]
    pub thinking: Option<ThinkingConfig>,

    /// Optional tool-choice setting.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[pyo3(get)]
    pub tool_choice: Option<ToolChoice>,

    /// Optional list of tools; `add_tools` creates the list on demand.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[pyo3(get)]
    pub tools: Option<Vec<Tool>>,

    /// Optional top-k sampling parameter.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[pyo3(get)]
    pub top_k: Option<i32>,

    /// Optional top-p sampling parameter.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[pyo3(get)]
    pub top_p: Option<f32>,

    /// Optional free-form JSON value. It has no generated Python getter, unlike
    /// the other fields.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub extra_body: Option<Value>,
}
1622
1623impl Default for AnthropicSettings {
1624    fn default() -> Self {
1625        Self {
1626            max_tokens: 4096,
1627            metadata: None,
1628            service_tier: None,
1629            stop_sequences: None,
1630            stream: Some(false),
1631            system: None,
1632            temperature: None,
1633            thinking: None,
1634            top_k: None,
1635            top_p: None,
1636            tools: None,
1637            tool_choice: None,
1638            extra_body: None,
1639        }
1640    }
1641}
1642
1643#[pymethods]
1644impl AnthropicSettings {
1645    #[new]
1646    #[pyo3(signature = (
1647        max_tokens=4096,
1648        metadata=None,
1649        service_tier=None,
1650        stop_sequences=None,
1651        stream=None,
1652        system =None,
1653        temperature=None,
1654        thinking=None,
1655        top_k=None,
1656        top_p=None,
1657        tools=None,
1658        tool_choice=None,
1659        extra_body=None
1660    ))]
1661    #[allow(clippy::too_many_arguments)]
1662    pub fn new(
1663        max_tokens: i32,
1664        metadata: Option<Metadata>,
1665        service_tier: Option<String>,
1666        stop_sequences: Option<Vec<String>>,
1667        stream: Option<bool>,
1668        system: Option<String>,
1669        temperature: Option<f32>,
1670        thinking: Option<ThinkingConfig>,
1671        top_k: Option<i32>,
1672        top_p: Option<f32>,
1673        tools: Option<Vec<Tool>>,
1674        tool_choice: Option<ToolChoice>,
1675        extra_body: Option<&Bound<'_, PyAny>>,
1676    ) -> Result<Self, UtilError> {
1677        let extra = match extra_body {
1678            Some(obj) => Some(depythonize(obj)?),
1679            None => None,
1680        };
1681
1682        Ok(Self {
1683            max_tokens,
1684            metadata,
1685            service_tier,
1686            stop_sequences,
1687            stream,
1688            system,
1689            temperature,
1690            thinking,
1691            top_k,
1692            top_p,
1693            tools,
1694            tool_choice,
1695            extra_body: extra,
1696        })
1697    }
1698
1699    pub fn __str__(&self) -> String {
1700        PyHelperFuncs::__str__(self)
1701    }
1702
1703    pub fn model_dump<'py>(&self, py: Python<'py>) -> Result<Bound<'py, PyAny>, TypeError> {
1704        // iterate over each field in model_settings and add to the dict if it is not None
1705        let json = serde_json::to_value(self)?;
1706        Ok(pythonize(py, &json)?)
1707    }
1708
1709    pub fn settings_type(&self) -> SettingsType {
1710        SettingsType::Anthropic
1711    }
1712}
1713
1714impl AnthropicSettings {
1715    pub fn add_tools(&mut self, tools: Vec<AgentToolDefinition>) -> Result<(), TypeError> {
1716        let current_tools = self.tools.get_or_insert_with(Vec::new);
1717
1718        for tool in tools {
1719            let tool_param = Tool::from_tool_agent_tool_definition(&tool)?;
1720            current_tools.push(tool_param);
1721        }
1722
1723        Ok(())
1724    }
1725}
1726
/// Request body for the Anthropic messages API (v1).
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct AnthropicMessageRequestV1 {
    /// Model identifier.
    pub model: String,
    /// Conversation messages.
    pub messages: Vec<MessageNum>,
    /// System instruction messages.
    pub system: Vec<MessageNum>,
    /// Model settings, flattened into the top level of the request body.
    // NOTE(review): `AnthropicSettings` also has a `system` field; when it is
    // `Some`, flattening emits a second `system` key next to the one above -
    // confirm which value is meant to win.
    #[serde(flatten)]
    pub settings: AnthropicSettings,
    /// Structured-output description; left out of the serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub output_format: Option<Value>,
}
1737
1738pub(crate) fn create_structured_output_schema(json_schema: &Value) -> Value {
1739    serde_json::json!({
1740        "type": "json_schema",
1741        "schema": json_schema,
1742
1743    })
1744}
1745
1746impl RequestAdapter for AnthropicMessageRequestV1 {
1747    fn messages_mut(&mut self) -> &mut Vec<MessageNum> {
1748        &mut self.messages
1749    }
1750    fn messages(&self) -> &[MessageNum] {
1751        &self.messages
1752    }
1753    fn system_instructions(&self) -> Vec<&MessageNum> {
1754        self.system.iter().collect()
1755    }
1756    fn response_json_schema(&self) -> Option<&Value> {
1757        let schema = self.output_format.as_ref();
1758
1759        // if Some, get "schema" field
1760        // output_format: {
1761        //     "type": "json_schema",
1762        //     "schema": { ... }
1763        // }
1764        schema.and_then(|of| of.get("schema"))
1765    }
1766
1767    fn preprend_system_instructions(&mut self, messages: Vec<MessageNum>) -> Result<(), TypeError> {
1768        let mut combined = messages;
1769        combined.append(&mut self.system);
1770        self.system = combined;
1771        Ok(())
1772    }
1773
1774    fn get_py_system_instructions<'py>(
1775        &self,
1776        py: Python<'py>,
1777    ) -> Result<Bound<'py, PyList>, TypeError> {
1778        let py_system_instructions = PyList::empty(py);
1779        for system_msg in &self.system {
1780            py_system_instructions.append(system_msg.to_bound_py_object(py)?)?;
1781        }
1782
1783        Ok(py_system_instructions)
1784    }
1785
1786    fn model_settings<'py>(&self, py: Python<'py>) -> Result<Bound<'py, PyAny>, TypeError> {
1787        let settings = self.settings.clone();
1788        Ok(settings.into_bound_py_any(py)?)
1789    }
1790
1791    fn to_request_body(&self) -> Result<Value, TypeError> {
1792        Ok(serde_json::to_value(self)?)
1793    }
1794    fn match_provider(&self, provider: &Provider) -> bool {
1795        *provider == Provider::Anthropic
1796    }
1797    fn build_provider_enum(
1798        messages: Vec<MessageNum>,
1799        system_instructions: Vec<MessageNum>,
1800        model: String,
1801        settings: ModelSettings,
1802        response_json_schema: Option<Value>,
1803    ) -> Result<ProviderRequest, TypeError> {
1804        let anthropic_settings = match settings {
1805            ModelSettings::AnthropicChat(s) => s,
1806            _ => AnthropicSettings::default(),
1807        };
1808
1809        let output_format =
1810            response_json_schema.map(|json_schema| create_structured_output_schema(&json_schema));
1811
1812        Ok(ProviderRequest::AnthropicV1(AnthropicMessageRequestV1 {
1813            model,
1814            messages,
1815            system: system_instructions,
1816            settings: anthropic_settings,
1817            output_format,
1818        }))
1819    }
1820
1821    fn set_response_json_schema(&mut self, response_json_schema: Option<Value>) {
1822        self.output_format =
1823            response_json_schema.map(|json_schema| create_structured_output_schema(&json_schema));
1824    }
1825
1826    fn add_tools(&mut self, tools: Vec<AgentToolDefinition>) -> Result<(), TypeError> {
1827        self.settings.add_tools(tools)
1828    }
1829}
1830
1831#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
1832#[pyclass]
1833pub struct SystemPrompt {
1834    #[pyo3(get)]
1835    #[serde(flatten)]
1836    pub content: Vec<TextBlockParam>,
1837}
1838
1839#[pymethods]
1840impl SystemPrompt {
1841    /// Create a new SystemPrompt
1842    /// Accepts either a single string or a list of TextBlockParams
1843    /// # Arguments
1844    /// * `content` - Either a string or a list of TextBlockParams
1845    /// # Returns
1846    /// * `SystemPrompt` - The created SystemPrompt
1847    /// Errors
1848    /// * `TypeError` - If the content is not a string or a list of TextBlockParams
1849    #[new]
1850    pub fn new(content: &Bound<'_, PyAny>) -> Result<Self, TypeError> {
1851        let content_blocks: Vec<TextBlockParam> =
1852            if content.is_instance_of::<pyo3::types::PyString>() {
1853                let text = content.extract::<String>()?;
1854                let text_block = TextBlockParam::new(text, None, None)?;
1855                vec![text_block]
1856            } else if content.is_instance_of::<pyo3::types::PyList>() {
1857                let content_list = content.extract::<Vec<Bound<'_, PyAny>>>()?;
1858                let mut blocks = Vec::new();
1859                for item in content_list {
1860                    let text_block = item.extract::<TextBlockParam>().map_err(|_| {
1861                        TypeError::InvalidInput(
1862                            "All items in the list must be TextBlockParam".to_string(),
1863                        )
1864                    })?;
1865                    blocks.push(text_block);
1866                }
1867                blocks
1868            } else {
1869                return Err(TypeError::InvalidInput(
1870                    "Content must be either a string or a list of TextBlockParam".to_string(),
1871                ));
1872            };
1873
1874        Ok(Self {
1875            content: content_blocks,
1876        })
1877    }
1878}