// openai_utils/lib.rs

1#![allow(dead_code)]
2#![feature(type_name_of_val)]
3#![feature(associated_type_defaults)]
4#![feature(async_fn_in_trait)]
5
6mod chat_completion;
7mod chat_completion_delta;
8mod chat_completion_request;
9mod error;
10
11use lazy_static::lazy_static;
12#[allow(unused_imports)]
13use log::{debug, error, info, trace, warn};
14use std::{
15    any::type_name_of_val,
16    sync::{Arc, RwLock},
17};
18
19use schemars::{schema_for, JsonSchema};
20use serde_derive::{Deserialize, Serialize};
21use serde_json::Value;
22pub use {
23    chat_completion::ChatCompletion as Chat,
24    chat_completion_delta::ChatCompletionDelta as ChatDelta, chat_completion_delta::DeltaReceiver,
25    chat_completion_request::AiAgent,
26    chat_completion_request::ChatCompletionRequest as ChatRequest,
27};
28
29lazy_static! {
30    static ref OPENAI_API_KEY: Arc<RwLock<Option<String>>> = Arc::new(RwLock::new(None));
31}
32
33#[derive(Debug, Clone, Serialize, Deserialize)]
34pub struct Message {
35    pub role: String,
36
37    #[serde(skip_serializing_if = "Option::is_none")]
38    pub content: Option<String>,
39
40    #[serde(skip_serializing_if = "Option::is_none")]
41    pub name: Option<String>,
42
43    #[serde(skip_serializing_if = "Option::is_none")]
44    pub function_call: Option<FunctionCall>,
45}
46
47impl Message {
48    pub fn new(role: impl Into<String>) -> Self {
49        Self {
50            role: role.into(),
51            content: None,
52            name: None,
53            function_call: None,
54        }
55    }
56
57    pub fn with_content(mut self, content: impl Into<String>) -> Self {
58        self.content = Some(content.into());
59        self
60    }
61
62    pub fn with_name(mut self, name: impl Into<String>) -> Self {
63        self.name = Some(name.into());
64        self
65    }
66}
67
68#[derive(Debug, Clone, Serialize, Deserialize)]
69pub struct FunctionCall {
70    pub name: String,
71    pub arguments: String,
72}
73
74#[derive(Debug, Clone, Serialize, Deserialize)]
75pub struct Function {
76    pub name: String,
77
78    #[serde(skip_serializing_if = "Option::is_none")]
79    pub description: Option<String>,
80    pub parameters: Value,
81}
82
83impl Function {
84    pub fn from<FunctionArgs, Func, T>(function: &Func, function_name: &str) -> Self
85    where
86        FunctionArgs: JsonSchema,
87        Func: FnMut(FunctionArgs) -> T,
88    {
89        let schema = schema_for!(FunctionArgs);
90        let parameters = serde_json::to_value(schema)
91            .unwrap_or_else(|_| panic!("Failed to serialize schema for function {}", function_name));
92        Self {
93            name: function_name.to_string(),
94            description: match parameters.get("description") {
95                Some(Value::String(s)) => Some(s.clone()),
96                _ => None,
97            },
98            parameters,
99        }
100    }
101}
102
103#[derive(Debug, Clone, Serialize, Deserialize)]
104pub struct Choice {
105    pub index: i64,
106    pub message: Message,
107    pub finish_reason: String,
108}
109
110#[derive(Debug, Clone, Serialize, Deserialize)]
111pub struct ChoiceDelta {
112    pub index: i64,
113    pub delta: Delta,
114
115    #[serde(default)]
116    pub finish_reason: Option<String>,
117}
118
119#[derive(Debug, Clone, Serialize, Deserialize)]
120pub struct Delta {
121    pub role: Option<String>,
122
123    pub content: Option<String>,
124
125    pub function_call: Option<FunctionCallDelta>,
126}
127
128#[derive(Debug, Clone, Serialize, Deserialize)]
129pub struct FunctionCallDelta {
130    pub name: Option<String>,
131    pub arguments: Option<String>,
132}
133
134#[derive(Default, Debug, Clone, Serialize, Deserialize)]
135pub struct Usage {
136    pub prompt_tokens: u64,
137    pub completion_tokens: u64,
138    pub total_tokens: u64,
139}
140
141pub fn api_key(api_key: String) {
142    let mut key = OPENAI_API_KEY.write().unwrap();
143    *key = Some(api_key);
144}
145
146#[derive(Default, Debug, Clone, JsonSchema)]
147#[schemars(description = "this function takes no arguments")]
148pub struct NoArgs {
149    _unused: (),
150}
151
152pub fn calculate_message_tokens(message: &Message) -> usize {
153    let bpe = tiktoken_rs::cl100k_base().unwrap();
154
155    bpe.encode_with_special_tokens(&message.content.as_ref().unwrap_or(&"".to_string())).len()
156}
157
158pub fn calculate_tokens(s: &str) -> usize {
159    let bpe = tiktoken_rs::cl100k_base().unwrap();
160
161    bpe.encode_with_special_tokens(s).len()
162}