openai_models/
lib.rs

1use std::{collections::VecDeque, str::FromStr};
2
3use derive_more::derive::Display;
4use serde::{Deserialize, Serialize};
5
6pub mod error;
7pub mod llm;
8
// Public re-export of the `async_openai` crate, so dependents can reach the
// OpenAI client types through this crate's namespace (`openai_models::openai`).
pub mod openai {
    pub use async_openai::*;
}
12
// General models, note might alias to a specific model
/// Known OpenAI model identifiers. `Display` renders the exact API model
/// string declared in each `#[display(...)]` attribute.
#[derive(Debug, Clone, Serialize, Deserialize, Display)]
pub enum OpenAIModel {
    #[display("gpt-4o")]
    GPT4O,
    #[display("gpt-4o-mini")]
    GPT4OMINI,
    #[display("o1")]
    O1,
    #[display("o1-mini")]
    O1MINI,
    #[display("gpt-3.5-turbo")]
    GPT35TURBO,
    #[display("gpt-4")]
    GPT4,
    #[display("gpt-4-turbo")]
    GPT4TURBO,
    /// Any other model, displayed as its raw name. Carries caller-supplied
    /// pricing because it is not covered by the built-in tables in `pricing()`.
    #[display("{_0}")]
    Other(String, PricingInfo),
}
33
34impl FromStr for OpenAIModel {
35    type Err = String;
36    fn from_str(s: &str) -> Result<Self, Self::Err> {
37        match s {
38            "gpt-4o" | "gpt4o" => Ok(Self::GPT4O),
39            "gpt-4" | "gpt" => Ok(Self::GPT4),
40            "gpt-4-turbo" | "gpt4turbo" => Ok(Self::GPT4TURBO),
41            "gpt-4o-mini" | "gpt4omini" => Ok(Self::GPT4OMINI),
42            "o1" => Ok(Self::O1),
43            "o1-mini" => Ok(Self::O1MINI),
44            "gpt-3.5-turbo" | "gpt3.5turbo" => Ok(Self::GPT35TURBO),
45            _ => {
46                if !s.contains(",") {
47                    return Ok(Self::Other(
48                        s.to_string(),
49                        PricingInfo {
50                            input_tokens: 0.0f64,
51                            output_tokens: 0.0f64,
52                            cached_input_tokens: None,
53                        },
54                    ));
55                }
56                let mut tks = s
57                    .split(",")
58                    .map(|t| t.to_string())
59                    .collect::<VecDeque<String>>();
60
61                if tks.len() >= 2 {
62                    let model = tks.pop_front().unwrap();
63                    let tks = tks
64                        .into_iter()
65                        .map(|t| f64::from_str(&t))
66                        .collect::<Result<Vec<f64>, _>>()
67                        .map_err(|e| e.to_string())?;
68
69                    let pricing = if tks.len() == 2 {
70                        PricingInfo {
71                            input_tokens: tks[0],
72                            output_tokens: tks[1],
73                            cached_input_tokens: None,
74                        }
75                    } else if tks.len() == 3 {
76                        PricingInfo {
77                            input_tokens: tks[0],
78                            output_tokens: tks[1],
79                            cached_input_tokens: Some(tks[2]),
80                        }
81                    } else {
82                        return Err("fail to parse pricing".to_string());
83                    };
84
85                    Ok(Self::Other(model, pricing))
86                } else {
87                    Err("unreconigized model".to_string())
88                }
89            }
90        }
91    }
92}
93
// USD per 1M tokens
// From https://openai.com/api/pricing/
#[derive(Copy, Debug, Clone, Serialize, Deserialize)]
pub struct PricingInfo {
    /// Price per 1M input (prompt) tokens, in USD.
    pub input_tokens: f64,
    /// Price per 1M output (completion) tokens, in USD.
    pub output_tokens: f64,
    /// Price per 1M cached input tokens, in USD; `None` when the model has no
    /// cached-input rate.
    pub cached_input_tokens: Option<f64>,
}
102
103impl FromStr for PricingInfo {
104    type Err = String;
105
106    fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
107        let tks = s
108            .split(",")
109            .map(f64::from_str)
110            .collect::<Result<Vec<f64>, _>>()
111            .map_err(|e| e.to_string())?;
112
113        if tks.len() == 2 {
114            Ok(PricingInfo {
115                input_tokens: tks[0],
116                output_tokens: tks[1],
117                cached_input_tokens: None,
118            })
119        } else if tks.len() == 3 {
120            Ok(PricingInfo {
121                input_tokens: tks[0],
122                output_tokens: tks[1],
123                cached_input_tokens: Some(tks[2]),
124            })
125        } else {
126            Err("fail to parse pricing".to_string())
127        }
128    }
129}
130
131impl OpenAIModel {
132    pub fn pricing(&self) -> PricingInfo {
133        match self {
134            Self::GPT4O => PricingInfo {
135                input_tokens: 2.5,
136                output_tokens: 10.00,
137                cached_input_tokens: Some(1.25),
138            },
139            Self::GPT4OMINI => PricingInfo {
140                input_tokens: 0.15,
141                cached_input_tokens: Some(0.075),
142                output_tokens: 0.6,
143            },
144            Self::O1 => PricingInfo {
145                input_tokens: 15.00,
146                cached_input_tokens: Some(7.5),
147                output_tokens: 60.00,
148            },
149            Self::O1MINI => PricingInfo {
150                input_tokens: 3.0,
151                cached_input_tokens: Some(1.5),
152                output_tokens: 12.00,
153            },
154            Self::GPT35TURBO => PricingInfo {
155                input_tokens: 3.0,
156                cached_input_tokens: None,
157                output_tokens: 6.0,
158            },
159            Self::GPT4 => PricingInfo {
160                input_tokens: 30.0,
161                output_tokens: 60.0,
162                cached_input_tokens: None,
163            },
164            Self::GPT4TURBO => PricingInfo {
165                input_tokens: 10.0,
166                output_tokens: 30.0,
167                cached_input_tokens: None,
168            },
169            Self::Other(_, pricing) => *pricing,
170        }
171    }
172
173    pub fn batch_pricing(&self) -> Option<PricingInfo> {
174        match self {
175            Self::GPT4O => Some(PricingInfo {
176                input_tokens: 1.25,
177                output_tokens: 5.00,
178                cached_input_tokens: None,
179            }),
180            Self::GPT4OMINI => Some(PricingInfo {
181                input_tokens: 0.075,
182                output_tokens: 0.3,
183                cached_input_tokens: None,
184            }),
185            _ => None,
186        }
187    }
188}