1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
use instruct_macros_types::InstructMacro;
use std::{collections::HashMap, vec};

use openai_api_rs::v1::{
    api::Client,
    chat_completion::{self, ChatCompletionRequest, JSONSchemaDefine},
    error::APIError,
};

use instruct_macros_types::{ParameterInfo, StructInfo};

/// Wrapper around an `openai_api_rs` [`Client`] that asks the model to fill in a Rust type
/// via a function tool and deserializes the tool-call arguments into that type.
pub struct InstructorClient {
    /// Underlying client used to send the chat-completion requests.
    client: Client,
}

impl InstructorClient {
    pub fn new(client: Client) -> Self {
        Self { client }
    }

    fn get_parameters<T>(parameters: Vec<ParameterInfo>) -> HashMap<String, Box<JSONSchemaDefine>>
    where
        T: for<'de> serde::Deserialize<'de>,
    {
        let mut properties = HashMap::new();

        for param in parameters {
            let schema_type = match param.r#type.as_str() {
                "String" => Some(chat_completion::JSONSchemaType::String),
                "u8" => Some(chat_completion::JSONSchemaType::Number),
                _ => None,
            };

            properties.insert(
                param.name.clone(),
                Box::new(chat_completion::JSONSchemaDefine {
                    schema_type,
                    description: Some(param.comment.clone()),
                    ..Default::default()
                }),
            );
        }

        properties
    }

    fn get_required(parameters: Vec<ParameterInfo>) -> Vec<String> {
        parameters.iter().map(|p| p.name.clone()).collect()
    }

    pub fn chat_completion<T>(
        &self,
        req: ChatCompletionRequest,
        max_retries: u8,
    ) -> Result<T, APIError>
    where
        T: InstructMacro + for<'de> serde::Deserialize<'de>,
    {
        let parsed_model: StructInfo = T::get_info();
        let mut error_message: Option<String> = None;

        for _ in 0..max_retries {
            let mut req = req.clone();

            if let Some(ref error) = error_message {
                let new_message = chat_completion::ChatCompletionMessage {
                    role: chat_completion::MessageRole::user,
                    content: chat_completion::Content::Text(error.clone()),
                    name: None,
                };
                req.messages.push(new_message);
            }

            let result = self._retry_sync::<T>(req.clone(), parsed_model.clone());
            match result {
                Ok(value) => {
                    return Ok(value);
                }
                Err(e) => {
                    error_message =
                        Some(format!("Validation Error: {:?}. Please fix the issue", e));
                    continue;
                }
            }
        }

        panic!("Unable to derive model")
    }

    fn _retry_sync<T>(
        &self,
        req: ChatCompletionRequest,
        parsed_model: StructInfo,
    ) -> Result<T, serde_json::Error>
    where
        T: InstructMacro + for<'de> serde::Deserialize<'de>,
    {
        let properties = Self::get_parameters::<T>(parsed_model.parameters.clone());

        let func_call = chat_completion::Tool {
            r#type: chat_completion::ToolType::Function,
            function: chat_completion::Function {
                name: parsed_model.name,
                description: Some(parsed_model.description),
                parameters: chat_completion::FunctionParameters {
                    schema_type: chat_completion::JSONSchemaType::Object,
                    properties: Some(properties),
                    required: Some(Self::get_required(parsed_model.parameters.clone())),
                },
            },
        };

        let req = req
            .tools(vec![func_call])
            .tool_choice(chat_completion::ToolChoiceType::Auto);

        let result = self.client.chat_completion(req).unwrap();

        match result.choices[0].finish_reason {
            Some(chat_completion::FinishReason::tool_calls) => {
                // TODO: Support more than one tool at some point?
                let tool_calls = result.choices[0].message.tool_calls.as_ref().unwrap();

                match tool_calls.len() {
                    1 => {
                        let tool_call = &tool_calls[0];
                        let arguments = tool_call.function.arguments.clone().unwrap();

                        return serde_json::from_str(&arguments);
                    }
                    _ => {
                        // TODO: Support multiple tool calls at some point
                        let error_message =
                            format!("Unexpected number of tool calls: {:?}", tool_calls);
                        return Err(serde::de::Error::custom(error_message));
                    }
                }
            }
            _ => panic!("Unexpected finish reason"),
        }
    }
}

/// Convenience constructor: wraps an `openai_api_rs` [`Client`] in an [`InstructorClient`].
///
/// Equivalent to calling `InstructorClient::new(client)`.
pub fn from_openai(client: Client) -> InstructorClient {
    InstructorClient::new(client)
}