ragit_api/
message.rs

1use crate::api_provider::ApiProvider;
2use ragit_pdl::{
3    Message,
4    MessageContent,
5    encode_base64,
6};
7use serde_json::{Map, Value};
8
9pub fn message_content_to_json(message: &MessageContent, api_provider: &ApiProvider) -> Value {
10    match message {
11        MessageContent::String(s) => {
12            let mut content = Map::new();
13
14            if api_provider != &ApiProvider::Google {
15                content.insert(String::from("type"), "text".into());
16            }
17
18            content.insert(String::from("text"), s.to_string().into());
19            content.into()
20        },
21        MessageContent::Image { image_type, bytes } => match api_provider {
22            ApiProvider::Anthropic => {
23                let mut content = Map::new();
24                content.insert(String::from("type"), "image".into());
25
26                let mut source = Map::new();
27                source.insert(String::from("type"), "base64".into());
28                source.insert(String::from("media_type"), image_type.get_media_type().into());
29                source.insert(String::from("data"), encode_base64(bytes).into());
30
31                content.insert(String::from("source"), source.into());
32                content.into()
33            },
34            ApiProvider::Google => {
35                let mut content = Map::new();
36                let mut inline_data = Map::new();
37
38                inline_data.insert(String::from("mime_type"), image_type.get_media_type().into());
39                inline_data.insert(String::from("data"), encode_base64(bytes).into());
40
41                content.insert(String::from("inline_data"), inline_data.into());
42                content.into()
43            },
44            // TODO: does cohere support images?
45            _ => {  // assume the others are all openai-compatible
46                let mut content = Map::new();
47                content.insert(String::from("type"), "image_url".into());
48
49                let mut image_url = Map::new();
50                image_url.insert(String::from("url"), format!("data:{};base64,{}", image_type.get_media_type(), encode_base64(bytes)).into());
51                content.insert(String::from("image_url"), image_url.into());
52                content.into()
53            },
54        },
55    }
56}
57
58pub fn message_contents_to_json_array(contents: &[MessageContent], api_provider: &ApiProvider) -> Value {
59    match api_provider {
60        ApiProvider::Google => Value::Array(contents.iter().map(
61            |content| message_content_to_json(content, api_provider)
62        ).collect()),
63        _ => {
64            if contents.len() == 1 && contents[0].is_string() {
65                Value::String(contents[0].unwrap_str().into())
66            }
67
68            else {
69                Value::Array(contents.iter().map(
70                    |content| message_content_to_json(content, api_provider)
71                ).collect())
72            }
73        },
74    }
75}
76
77pub fn message_to_json(message: &Message, api_provider: &ApiProvider) -> Value {
78    let mut result = Map::new();
79    result.insert(String::from("role"), message.role.to_api_string(matches!(api_provider, ApiProvider::Google)).into());
80
81    match (api_provider, message.content.len()) {
82        (_, 0) => panic!("a message without any content"),
83        (ApiProvider::Google, _) => {
84            result.insert(String::from("parts"), message_contents_to_json_array(&message.content, api_provider));
85        },
86        (ApiProvider::Anthropic, 1) if matches!(&message.content[0], MessageContent::String(_)) => match &message.content[0] {
87            MessageContent::String(s) => {
88                result.insert(String::from("content"), s.to_string().into());
89            },
90            MessageContent::Image { .. } => unreachable!(),
91        },
92        (ApiProvider::Anthropic | ApiProvider::Cohere | ApiProvider::OpenAi { .. }, _) => {
93            result.insert(String::from("content"), message_contents_to_json_array(&message.content, api_provider));
94        },
95        (ApiProvider::Test(_), _) => unreachable!(),
96    }
97
98    result.into()
99}