1use std::{collections::VecDeque, str::FromStr};
2
3use derive_more::derive::Display;
4use serde::{Deserialize, Serialize};
5
6pub mod error;
7pub mod llm;
8
9pub mod openai {
10 pub use async_openai::*;
11}
12
/// Known OpenAI models, plus an `Other` escape hatch for arbitrary
/// model names carrying caller-supplied pricing.
///
/// `Display` (via `derive_more`) renders the wire-format model name
/// (e.g. "gpt-4o-mini"); `Other` displays only its inner name string.
#[derive(Debug, Clone, Serialize, Deserialize, Display)]
pub enum OpenAIModel {
    #[display("gpt-4o")]
    GPT4O,
    #[display("gpt-4o-mini")]
    GPT4OMINI,
    #[display("o1")]
    O1,
    #[display("o1-mini")]
    O1MINI,
    #[display("gpt-3.5-turbo")]
    GPT35TURBO,
    #[display("gpt-4")]
    GPT4,
    #[display("gpt-4-turbo")]
    GPT4TURBO,
    /// Any other model: `(name, pricing)`; displays as the bare name.
    #[display("{_0}")]
    Other(String, PricingInfo),
}
33
34impl FromStr for OpenAIModel {
35 type Err = String;
36 fn from_str(s: &str) -> Result<Self, Self::Err> {
37 match s {
38 "gpt-4o" | "gpt4o" => Ok(Self::GPT4O),
39 "gpt-4" | "gpt" => Ok(Self::GPT4),
40 "gpt-4-turbo" | "gpt4turbo" => Ok(Self::GPT4TURBO),
41 "gpt-4o-mini" | "gpt4omini" => Ok(Self::GPT4OMINI),
42 "o1" => Ok(Self::O1),
43 "o1-mini" => Ok(Self::O1MINI),
44 "gpt-3.5-turbo" | "gpt3.5turbo" => Ok(Self::GPT35TURBO),
45 _ => {
46 if !s.contains(",") {
47 return Ok(Self::Other(
48 s.to_string(),
49 PricingInfo {
50 input_tokens: 0.0f64,
51 output_tokens: 0.0f64,
52 cached_input_tokens: None,
53 },
54 ));
55 }
56 let mut tks = s
57 .split(",")
58 .map(|t| t.to_string())
59 .collect::<VecDeque<String>>();
60
61 if tks.len() >= 2 {
62 let model = tks.pop_front().unwrap();
63 let tks = tks
64 .into_iter()
65 .map(|t| f64::from_str(&t))
66 .collect::<Result<Vec<f64>, _>>()
67 .map_err(|e| e.to_string())?;
68
69 let pricing = if tks.len() == 2 {
70 PricingInfo {
71 input_tokens: tks[0],
72 output_tokens: tks[1],
73 cached_input_tokens: None,
74 }
75 } else if tks.len() == 3 {
76 PricingInfo {
77 input_tokens: tks[0],
78 output_tokens: tks[1],
79 cached_input_tokens: Some(tks[2]),
80 }
81 } else {
82 return Err("fail to parse pricing".to_string());
83 };
84
85 Ok(Self::Other(model, pricing))
86 } else {
87 Err("unreconigized model".to_string())
88 }
89 }
90 }
91 }
92}
93
/// Per-token prices for a model.
///
/// Values are plain `f64` rates; the unit is not established in this file —
/// presumably USD per 1M tokens, matching OpenAI's published pricing.
/// NOTE(review): confirm the unit convention with callers.
#[derive(Copy, Debug, Clone, Serialize, Deserialize)]
pub struct PricingInfo {
    // Rate for regular (non-cached) input tokens.
    pub input_tokens: f64,
    // Rate for output tokens.
    pub output_tokens: f64,
    // Discounted rate for cached input tokens, when the model offers one.
    pub cached_input_tokens: Option<f64>,
}
102
103impl FromStr for PricingInfo {
104 type Err = String;
105
106 fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
107 let tks = s
108 .split(",")
109 .map(f64::from_str)
110 .collect::<Result<Vec<f64>, _>>()
111 .map_err(|e| e.to_string())?;
112
113 if tks.len() == 2 {
114 Ok(PricingInfo {
115 input_tokens: tks[0],
116 output_tokens: tks[1],
117 cached_input_tokens: None,
118 })
119 } else if tks.len() == 3 {
120 Ok(PricingInfo {
121 input_tokens: tks[0],
122 output_tokens: tks[1],
123 cached_input_tokens: Some(tks[2]),
124 })
125 } else {
126 Err("fail to parse pricing".to_string())
127 }
128 }
129}
130
131impl OpenAIModel {
132 pub fn pricing(&self) -> PricingInfo {
133 match self {
134 Self::GPT4O => PricingInfo {
135 input_tokens: 2.5,
136 output_tokens: 10.00,
137 cached_input_tokens: Some(1.25),
138 },
139 Self::GPT4OMINI => PricingInfo {
140 input_tokens: 0.15,
141 cached_input_tokens: Some(0.075),
142 output_tokens: 0.6,
143 },
144 Self::O1 => PricingInfo {
145 input_tokens: 15.00,
146 cached_input_tokens: Some(7.5),
147 output_tokens: 60.00,
148 },
149 Self::O1MINI => PricingInfo {
150 input_tokens: 3.0,
151 cached_input_tokens: Some(1.5),
152 output_tokens: 12.00,
153 },
154 Self::GPT35TURBO => PricingInfo {
155 input_tokens: 3.0,
156 cached_input_tokens: None,
157 output_tokens: 6.0,
158 },
159 Self::GPT4 => PricingInfo {
160 input_tokens: 30.0,
161 output_tokens: 60.0,
162 cached_input_tokens: None,
163 },
164 Self::GPT4TURBO => PricingInfo {
165 input_tokens: 10.0,
166 output_tokens: 30.0,
167 cached_input_tokens: None,
168 },
169 Self::Other(_, pricing) => *pricing,
170 }
171 }
172
173 pub fn batch_pricing(&self) -> Option<PricingInfo> {
174 match self {
175 Self::GPT4O => Some(PricingInfo {
176 input_tokens: 1.25,
177 output_tokens: 5.00,
178 cached_input_tokens: None,
179 }),
180 Self::GPT4OMINI => Some(PricingInfo {
181 input_tokens: 0.075,
182 output_tokens: 0.3,
183 cached_input_tokens: None,
184 }),
185 _ => None,
186 }
187 }
188}