1#![allow(dead_code)]
2#![feature(type_name_of_val)]
3#![feature(associated_type_defaults)]
4#![feature(async_fn_in_trait)]
5
6mod chat_completion;
7mod chat_completion_delta;
8mod chat_completion_request;
9mod error;
10
11use lazy_static::lazy_static;
12#[allow(unused_imports)]
13use log::{debug, error, info, trace, warn};
14use std::{
15 any::type_name_of_val,
16 sync::{Arc, RwLock},
17};
18
19use schemars::{schema_for, JsonSchema};
20use serde_derive::{Deserialize, Serialize};
21use serde_json::Value;
22pub use {
23 chat_completion::ChatCompletion as Chat,
24 chat_completion_delta::ChatCompletionDelta as ChatDelta, chat_completion_delta::DeltaReceiver,
25 chat_completion_request::AiAgent,
26 chat_completion_request::ChatCompletionRequest as ChatRequest,
27};
28
29lazy_static! {
30 static ref OPENAI_API_KEY: Arc<RwLock<Option<String>>> = Arc::new(RwLock::new(None));
31}
32
33#[derive(Debug, Clone, Serialize, Deserialize)]
34pub struct Message {
35 pub role: String,
36
37 #[serde(skip_serializing_if = "Option::is_none")]
38 pub content: Option<String>,
39
40 #[serde(skip_serializing_if = "Option::is_none")]
41 pub name: Option<String>,
42
43 #[serde(skip_serializing_if = "Option::is_none")]
44 pub function_call: Option<FunctionCall>,
45}
46
47impl Message {
48 pub fn new(role: impl Into<String>) -> Self {
49 Self {
50 role: role.into(),
51 content: None,
52 name: None,
53 function_call: None,
54 }
55 }
56
57 pub fn with_content(mut self, content: impl Into<String>) -> Self {
58 self.content = Some(content.into());
59 self
60 }
61
62 pub fn with_name(mut self, name: impl Into<String>) -> Self {
63 self.name = Some(name.into());
64 self
65 }
66}
67
68#[derive(Debug, Clone, Serialize, Deserialize)]
69pub struct FunctionCall {
70 pub name: String,
71 pub arguments: String,
72}
73
74#[derive(Debug, Clone, Serialize, Deserialize)]
75pub struct Function {
76 pub name: String,
77
78 #[serde(skip_serializing_if = "Option::is_none")]
79 pub description: Option<String>,
80 pub parameters: Value,
81}
82
83impl Function {
84 pub fn from<FunctionArgs, Func, T>(function: &Func, function_name: &str) -> Self
85 where
86 FunctionArgs: JsonSchema,
87 Func: FnMut(FunctionArgs) -> T,
88 {
89 let schema = schema_for!(FunctionArgs);
90 let parameters = serde_json::to_value(schema)
91 .unwrap_or_else(|_| panic!("Failed to serialize schema for function {}", function_name));
92 Self {
93 name: function_name.to_string(),
94 description: match parameters.get("description") {
95 Some(Value::String(s)) => Some(s.clone()),
96 _ => None,
97 },
98 parameters,
99 }
100 }
101}
102
103#[derive(Debug, Clone, Serialize, Deserialize)]
104pub struct Choice {
105 pub index: i64,
106 pub message: Message,
107 pub finish_reason: String,
108}
109
110#[derive(Debug, Clone, Serialize, Deserialize)]
111pub struct ChoiceDelta {
112 pub index: i64,
113 pub delta: Delta,
114
115 #[serde(default)]
116 pub finish_reason: Option<String>,
117}
118
119#[derive(Debug, Clone, Serialize, Deserialize)]
120pub struct Delta {
121 pub role: Option<String>,
122
123 pub content: Option<String>,
124
125 pub function_call: Option<FunctionCallDelta>,
126}
127
128#[derive(Debug, Clone, Serialize, Deserialize)]
129pub struct FunctionCallDelta {
130 pub name: Option<String>,
131 pub arguments: Option<String>,
132}
133
134#[derive(Default, Debug, Clone, Serialize, Deserialize)]
135pub struct Usage {
136 pub prompt_tokens: u64,
137 pub completion_tokens: u64,
138 pub total_tokens: u64,
139}
140
141pub fn api_key(api_key: String) {
142 let mut key = OPENAI_API_KEY.write().unwrap();
143 *key = Some(api_key);
144}
145
146#[derive(Default, Debug, Clone, JsonSchema)]
147#[schemars(description = "this function takes no arguments")]
148pub struct NoArgs {
149 _unused: (),
150}
151
152pub fn calculate_message_tokens(message: &Message) -> usize {
153 let bpe = tiktoken_rs::cl100k_base().unwrap();
154
155 bpe.encode_with_special_tokens(&message.content.as_ref().unwrap_or(&"".to_string())).len()
156}
157
158pub fn calculate_tokens(s: &str) -> usize {
159 let bpe = tiktoken_rs::cl100k_base().unwrap();
160
161 bpe.encode_with_special_tokens(s).len()
162}