openai_models/llm.rs

use std::{
    fmt::{Debug, Display},
    ops::Deref,
    path::{Path, PathBuf},
    str::FromStr,
    sync::{
        Arc,
        atomic::{AtomicU64, Ordering},
    },
    time::Duration,
};

use async_openai::{
    Client,
    config::{AzureConfig, OpenAIConfig},
    error::OpenAIError,
    types::{
        ChatCompletionRequestAssistantMessageContent,
        ChatCompletionRequestAssistantMessageContentPart,
        ChatCompletionRequestDeveloperMessageContent, ChatCompletionRequestMessage,
        ChatCompletionRequestSystemMessageArgs, ChatCompletionRequestSystemMessageContent,
        ChatCompletionRequestSystemMessageContentPart, ChatCompletionRequestToolMessageContent,
        ChatCompletionRequestToolMessageContentPart, ChatCompletionRequestUserMessageArgs,
        ChatCompletionRequestUserMessageContent, ChatCompletionRequestUserMessageContentPart,
        ChatCompletionResponseMessage, ChatCompletionToolChoiceOption, CreateChatCompletionRequest,
        CreateChatCompletionRequestArgs, CreateChatCompletionResponse,
    },
};
use clap::Args;
use color_eyre::{
    Result,
    eyre::{OptionExt, eyre},
};
use itertools::Itertools;
use log::{debug, info, trace, warn};
use serde::{Deserialize, Serialize};
use tokio::{io::AsyncWriteExt, sync::RwLock};

use crate::{OpenAIModel, error::PromptError};

// Wrapper around ChatCompletionToolChoiceOption; the upstream implementation is
// flawed, so we provide our own string parsing here.
#[derive(Debug, Clone)]
pub struct LLMToolChoice(pub ChatCompletionToolChoiceOption);

impl FromStr for LLMToolChoice {
    type Err = PromptError;
    fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
        Ok(match s {
            "auto" => Self(ChatCompletionToolChoiceOption::Auto),
            "required" => Self(ChatCompletionToolChoiceOption::Required),
            "none" => Self(ChatCompletionToolChoiceOption::None),
            _ => Self(ChatCompletionToolChoiceOption::Named(s.into())),
        })
    }
}

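/// Generation settings applied to each chat-completion request: temperature,
/// presence penalty, per-prompt timeout, retry budget, completion-token limit,
/// and tool-choice mode. Every field can be set by CLI flag or environment variable.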
#[derive(Args, Clone, Debug)]
pub struct LLMSettings {
    #[arg(long, env = "LLM_TEMPERATURE", default_value_t = 0.8)]
    pub llm_temperature: f32,

    #[arg(long, env = "LLM_PRESENCE_PENALTY", default_value_t = 0.0)]
    pub llm_presence_penalty: f32,

    #[arg(long, env = "LLM_PROMPT_TIMEOUT", default_value_t = 120)]
    pub llm_prompt_timeout: u64,

    #[arg(long, env = "LLM_RETRY", default_value_t = 5)]
    pub llm_retry: u64,

    #[arg(long, env = "LLM_MAX_COMPLETION_TOKENS", default_value_t = 16384)]
    pub llm_max_completion_tokens: u32,

    #[arg(long, env = "LLM_TOOL_CHOICE", default_value = "auto")]
    pub llm_tool_choice: ChatCompletionToolChoiceOption,
}

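/// Backend connection and model setup: base URL, API key, optional Azure
/// deployment endpoint, billing cap, model selection, and an optional
/// directory for dumping LLM request/response transcripts.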
#[derive(Args, Clone, Debug)]
pub struct OpenAISetup {
    #[arg(
        long,
        env = "OPENAI_API_URL",
        default_value = "https://api.openai.com/v1"
    )]
    pub openai_url: String,

    #[arg(long, env = "OPENAI_API_KEY")]
    pub openai_key: Option<String>,

    #[arg(long, env = "OPENAI_API_ENDPOINT")]
    pub openai_endpoint: Option<String>,

    #[arg(long, default_value_t = 10.0, env = "OPENAI_BILLING_CAP")]
    pub billing_cap: f64,

    #[arg(long, env = "OPENAI_API_MODEL", default_value = "o1")]
    pub model: OpenAIModel,

    #[arg(long, env = "LLM_DEBUG")]
    pub llm_debug: Option<PathBuf>,

    #[clap(flatten)]
    pub llm_settings: LLMSettings,
}

impl OpenAISetup {
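    /// Build an Azure configuration when a deployment endpoint is given,
    /// otherwise a plain OpenAI configuration.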
    pub fn to_config(&self) -> SupportedConfig {
        if let Some(ep) = self.openai_endpoint.as_ref() {
            let cfg = AzureConfig::new()
                .with_api_base(&self.openai_url)
                .with_api_key(self.openai_key.clone().unwrap_or_default())
                .with_deployment_id(ep);
            SupportedConfig::Azure(cfg)
        } else {
            let cfg = OpenAIConfig::new()
                .with_api_base(&self.openai_url)
                .with_api_key(self.openai_key.clone().unwrap_or_default());
            SupportedConfig::OpenAI(cfg)
        }
    }

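    /// Construct the shared LLM handle: the backend client, billing tracker,
    /// and, when a debug directory is configured, a fresh per-process
    /// subdirectory for saving LLM interactions.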
    pub fn to_llm(&self) -> LLM {
        let billing = RwLock::new(ModelBilling::new(self.billing_cap));

        let debug_path = if let Some(dbg) = self.llm_debug.as_ref() {
            let pid = std::process::id();

            let mut cnt = 0u64;
            let debug_path;
            loop {
                let test_path = dbg.join(format!("{}-{}", pid, cnt));
                if !test_path.exists() {
                    std::fs::create_dir_all(&test_path).expect("failed to create llm debug path");
                    debug_path = Some(test_path);
                    debug!("The path to save LLM interactions is {:?}", &debug_path);
                    break;
                } else {
                    cnt += 1;
                }
            }
            debug_path
        } else {
            None
        };

        LLM {
            llm: Arc::new(LLMInner {
                client: LLMClient::new(self.to_config()),
                model: self.model.clone(),
                billing,
                llm_debug: debug_path,
                llm_debug_index: AtomicU64::new(0),
                default_settings: self.llm_settings.clone(),
            }),
        }
    }
}

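/// Client configuration for the two supported backends.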
#[derive(Debug, Clone)]
pub enum SupportedConfig {
    Azure(AzureConfig),
    OpenAI(OpenAIConfig),
}

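/// A chat client for either backend, so the rest of the code can stay agnostic
/// to whether Azure or plain OpenAI is in use.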
#[derive(Debug, Clone)]
pub enum LLMClient {
    Azure(Client<AzureConfig>),
    OpenAI(Client<OpenAIConfig>),
}

impl LLMClient {
    pub fn new(config: SupportedConfig) -> Self {
        match config {
            SupportedConfig::Azure(cfg) => Self::Azure(Client::with_config(cfg)),
            SupportedConfig::OpenAI(cfg) => Self::OpenAI(Client::with_config(cfg)),
        }
    }

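    /// Dispatch a chat-completion request to whichever backend this client wraps.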
    pub async fn create_chat(
        &self,
        req: CreateChatCompletionRequest,
    ) -> Result<CreateChatCompletionResponse, OpenAIError> {
        match self {
            Self::Azure(cl) => cl.chat().create(req).await,
            Self::OpenAI(cl) => cl.chat().create(req).await,
        }
    }
}

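/// Running spend tracker: accumulates the estimated cost of tokens used and
/// enforces a hard billing cap.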
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelBilling {
    pub current: f64,
    pub cap: f64,
}

impl Display for ModelBilling {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_fmt(format_args!("Billing({}/{})", self.current, self.cap))
    }
}

impl ModelBilling {
    pub fn new(cap: f64) -> Self {
        Self { current: 0.0, cap }
    }

    pub fn in_cap(&self) -> bool {
        self.current <= self.cap
    }

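    /// Charge for `count` prompt (input) tokens using the model's
    /// per-million-token price; errors once the cap is exceeded.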
    pub fn input_tokens(&mut self, model: &OpenAIModel, count: u64) -> Result<()> {
        let pricing = model.pricing();

        self.current += (pricing.input_tokens * (count as f64)) / 1e6;

        if self.in_cap() {
            Ok(())
        } else {
            Err(eyre!("cap {} reached, current {}", self.cap, self.current))
        }
    }

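    /// Charge for `count` completion (output) tokens using the model's
    /// per-million-token price; errors once the cap is exceeded.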
    pub fn output_tokens(&mut self, model: &OpenAIModel, count: u64) -> Result<()> {
        let pricing = model.pricing();

        self.current += pricing.output_tokens * (count as f64) / 1e6;

        if self.in_cap() {
            Ok(())
        } else {
            Err(eyre!("cap {} reached, current {}", self.cap, self.current))
        }
    }
}

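/// Cheaply cloneable handle around the shared `LLMInner` state.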
#[derive(Debug, Clone)]
pub struct LLM {
    pub llm: Arc<LLMInner>,
}

impl Deref for LLM {
    type Target = LLMInner;

    fn deref(&self) -> &Self::Target {
        &self.llm
    }
}

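/// Shared state behind an `LLM` handle: backend client, model, billing tracker,
/// optional debug-transcript directory, and the counter used to number
/// transcript files.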
#[derive(Debug)]
pub struct LLMInner {
    pub client: LLMClient,
    pub model: OpenAIModel,
    pub billing: RwLock<ModelBilling>,
    pub llm_debug: Option<PathBuf>,
    pub llm_debug_index: AtomicU64,
    pub default_settings: LLMSettings,
}

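/// Map a request message to the role tag used in the debug transcripts.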
pub fn completion_to_role(msg: &ChatCompletionRequestMessage) -> &'static str {
    match msg {
        ChatCompletionRequestMessage::Assistant(_) => "ASSISTANT",
        ChatCompletionRequestMessage::Developer(_) => "DEVELOPER",
        ChatCompletionRequestMessage::Function(_) => "FUNCTION",
        ChatCompletionRequestMessage::System(_) => "SYSTEM",
        ChatCompletionRequestMessage::Tool(_) => "TOOL",
        ChatCompletionRequestMessage::User(_) => "USER",
    }
}

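/// Render a response message (content, tool calls, refusal) as an XML-like
/// block for the debug transcript.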
pub fn response_to_string(resp: &ChatCompletionResponseMessage) -> String {
    let mut s = String::new();
    if let Some(content) = resp.content.as_ref() {
        s += content;
        s += "\n";
    }

    if let Some(tools) = resp.tool_calls.as_ref() {
        s += &tools
            .iter()
            .map(|t| {
                format!(
                    "<toolcall name=\"{}\">\n{}\n</toolcall>",
                    &t.function.name, &t.function.arguments
                )
            })
            .join("\n");
    }

    if let Some(refusal) = &resp.refusal {
        s += refusal;
        s += "\n";
    }

    let role = resp.role.to_string().to_uppercase();

    format!("<{}>\n{}\n</{}>\n", &role, s, &role)
}

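/// Render a request message of any role as an XML-like block for the debug transcript.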
pub fn completion_to_string(msg: &ChatCompletionRequestMessage) -> String {
    const CONT: &str = "<cont/>\n";
    const NONE: &str = "<none/>\n";
    let role = completion_to_role(msg);
    let content = match msg {
        ChatCompletionRequestMessage::Assistant(ass) => {
            let msg = ass
                .content
                .as_ref()
                .map(|ass| match ass {
                    ChatCompletionRequestAssistantMessageContent::Text(s) => s.clone(),
                    ChatCompletionRequestAssistantMessageContent::Array(arr) => arr
                        .iter()
                        .map(|v| match v {
                            ChatCompletionRequestAssistantMessageContentPart::Text(s) => {
                                s.text.clone()
                            }
                            ChatCompletionRequestAssistantMessageContentPart::Refusal(rf) => {
                                rf.refusal.clone()
                            }
                        })
                        .join(CONT),
                })
                .unwrap_or(NONE.to_string());
            let tool_calls = ass
                .tool_calls
                .iter()
                .flatten()
                .map(|t| {
                    format!(
                        "<toolcall name=\"{}\">{}</toolcall>",
                        &t.function.name, &t.function.arguments
                    )
                })
                .join("\n");
            format!("{}\n{}", msg, tool_calls)
        }
        ChatCompletionRequestMessage::Developer(dev) => match &dev.content {
            ChatCompletionRequestDeveloperMessageContent::Text(t) => t.clone(),
            ChatCompletionRequestDeveloperMessageContent::Array(arr) => {
                arr.iter().map(|v| v.text.clone()).join(CONT)
            }
        },
        ChatCompletionRequestMessage::Function(f) => f.content.clone().unwrap_or(NONE.to_string()),
        ChatCompletionRequestMessage::System(sys) => match &sys.content {
            ChatCompletionRequestSystemMessageContent::Text(t) => t.clone(),
            ChatCompletionRequestSystemMessageContent::Array(arr) => arr
                .iter()
                .map(|v| match v {
                    ChatCompletionRequestSystemMessageContentPart::Text(t) => t.text.clone(),
                })
                .join(CONT),
        },
        ChatCompletionRequestMessage::Tool(tool) => match &tool.content {
            ChatCompletionRequestToolMessageContent::Text(t) => t.clone(),
            ChatCompletionRequestToolMessageContent::Array(arr) => arr
                .iter()
                .map(|v| match v {
                    ChatCompletionRequestToolMessageContentPart::Text(t) => t.text.clone(),
                })
                .join(CONT),
        },
        ChatCompletionRequestMessage::User(usr) => match &usr.content {
            ChatCompletionRequestUserMessageContent::Text(t) => t.clone(),
            ChatCompletionRequestUserMessageContent::Array(arr) => arr
                .iter()
                .map(|v| match v {
                    ChatCompletionRequestUserMessageContentPart::Text(t) => t.text.clone(),
                    ChatCompletionRequestUserMessageContentPart::ImageUrl(img) => {
                        format!("<img url=\"{}\"/>", &img.image_url.url)
                    }
                    ChatCompletionRequestUserMessageContentPart::InputAudio(audio) => {
                        format!("<audio>{}</audio>", audio.input_audio.data)
                    }
                })
                .join(CONT),
        },
    };

    format!("<{}>\n{}\n</{}>\n", role, content, role)
}

impl LLMInner {
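    /// Append the value as a single JSON line to a sibling file that shares the
    /// transcript's file stem but uses a `.json` extension.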
    async fn rewrite_json<T: Serialize + Debug>(fpath: &Path, t: &T) -> Result<(), PromptError> {
        let mut json_fp = fpath.to_path_buf();
        json_fp.set_file_name(format!(
            "{}.json",
            json_fp
                .file_stem()
                .ok_or_eyre(eyre!("no filename"))?
                .to_str()
                .ok_or_eyre(eyre!("non-utf8 filename"))?
        ));

        let mut fp = tokio::fs::OpenOptions::new()
            .create(true)
            .append(true)
            .write(true)
            .open(&json_fp)
            .await?;
        let s = match serde_json::to_string(&t) {
            Ok(s) => s,
            Err(_) => format!("{:?}", &t),
        };
        fp.write_all(s.as_bytes()).await?;
        fp.write_all(b"\n").await?;
        fp.flush().await?;

        Ok(())
    }

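    /// Write the outgoing request (all messages plus tool definitions) to the
    /// transcript file, then mirror the raw request as JSON via `rewrite_json`.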
    async fn save_llm_user(
        fpath: &PathBuf,
        user_msg: &CreateChatCompletionRequest,
    ) -> Result<(), PromptError> {
        let mut fp = tokio::fs::OpenOptions::new()
            .create(true)
            .truncate(true)
            .write(true)
            .open(&fpath)
            .await?;
        fp.write_all(b"<Request>\n").await?;
        for it in user_msg.messages.iter() {
            let msg = completion_to_string(it);
            fp.write_all(msg.as_bytes()).await?;
        }

        let mut tools = vec![];
        for tool in user_msg
            .tools
            .as_ref()
            .map(|t: &Vec<async_openai::types::ChatCompletionTool>| t.iter())
            .into_iter()
            .flatten()
        {
            tools.push(format!(
                "<tool name=\"{}\", description=\"{}\", strict={}>\n{}\n</tool>",
                &tool.function.name,
                &tool.function.description.clone().unwrap_or_default(),
                tool.function.strict.unwrap_or_default(),
                tool.function
                    .parameters
                    .as_ref()
                    .map(serde_json::to_string_pretty)
                    .transpose()?
                    .unwrap_or_default()
            ));
        }
        fp.write_all(tools.join("\n").as_bytes()).await?;
        fp.write_all(b"\n</Request>\n").await?;
        fp.flush().await?;

        Self::rewrite_json(fpath, user_msg).await?;

        Ok(())
    }

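    /// Append the model's choices to the same transcript file and mirror the raw
    /// response as JSON.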
    async fn save_llm_resp(fpath: &PathBuf, resp: &CreateChatCompletionResponse) -> Result<()> {
        let mut fp = tokio::fs::OpenOptions::new()
            .create(false)
            .append(true)
            .write(true)
            .open(&fpath)
            .await?;
        fp.write_all(b"<Response>\n").await?;
        for it in &resp.choices {
            let msg = response_to_string(&it.message);
            fp.write_all(msg.as_bytes()).await?;
        }
        fp.write_all(b"\n</Response>\n").await?;
        fp.flush().await?;

        Self::rewrite_json(fpath, resp).await?;

        Ok(())
    }

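    /// If transcript debugging is enabled, allocate the next numbered transcript
    /// path under the debug directory.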
    fn on_llm_debug(&self, prefix: &str) -> Option<PathBuf> {
        if let Some(output_folder) = self.llm_debug.as_ref() {
            let idx = self.llm_debug_index.fetch_add(1, Ordering::SeqCst);
            let fpath = output_folder.join(format!("{}-{:0>12}.xml", prefix, idx));
            Some(fpath)
        } else {
            None
        }
    }

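    /// Build a system + user request from the given strings and submit it through
    /// `complete_once_with_retry`, using `settings` (or the defaults).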
    // We bound each prompt with a timeout (estimated from tokens/s) so a request
    // that repeats indefinitely is cut off rather than hanging forever.
    pub async fn prompt_once_with_retry(
        &self,
        sys_msg: &str,
        user_msg: &str,
        prefix: Option<&str>,
        settings: Option<LLMSettings>,
    ) -> Result<CreateChatCompletionResponse, PromptError> {
        let settings = settings.unwrap_or(self.default_settings.clone());
        let sys = ChatCompletionRequestSystemMessageArgs::default()
            .content(sys_msg)
            .build()?;

        let user = ChatCompletionRequestUserMessageArgs::default()
            .content(user_msg)
            .build()?;
        let req = CreateChatCompletionRequestArgs::default()
            .messages(vec![sys.into(), user.into()])
            .model(self.model.to_string())
            .temperature(settings.llm_temperature)
            .presence_penalty(settings.llm_presence_penalty)
            .max_completion_tokens(settings.llm_max_completion_tokens)
            .tool_choice(settings.llm_tool_choice)
            .build()?;

        let timeout = if settings.llm_prompt_timeout == 0 {
            Duration::MAX
        } else {
            Duration::from_secs(settings.llm_prompt_timeout)
        };

        self.complete_once_with_retry(&req, prefix, Some(timeout), Some(settings.llm_retry))
            .await
    }

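    /// Call `complete` up to `retry` times, bounding each attempt by `timeout`;
    /// `None` means effectively unlimited for either parameter.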
    pub async fn complete_once_with_retry(
        &self,
        req: &CreateChatCompletionRequest,
        prefix: Option<&str>,
        timeout: Option<Duration>,
        retry: Option<u64>,
    ) -> Result<CreateChatCompletionResponse, PromptError> {
        let timeout = timeout.unwrap_or(Duration::MAX);
        let retry = retry.unwrap_or(u64::MAX);

        let mut last = None;
        for idx in 0..retry {
            match tokio::time::timeout(timeout, self.complete(req.clone(), prefix)).await {
                Ok(r) => {
                    last = Some(r);
                }
                Err(_) => {
                    warn!("Timed out on attempt {}, timeout = {:?}", idx, timeout);
                    continue;
                }
            };

            match last {
                Some(Ok(r)) => return Ok(r),
                Some(Err(ref e)) => {
                    warn!(
                        "Error {} on attempt {} (timeout is {:?})",
                        e, idx, timeout
                    );
                }
                _ => {}
            }
        }

        last.ok_or_eyre(eyre!("retry count is zero"))
            .map_err(PromptError::Other)?
    }

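    /// Send a single chat-completion request: optionally dump the request and the
    /// response to the debug transcript, then charge the reported token usage
    /// against the billing cap.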
    pub async fn complete(
        &self,
        req: CreateChatCompletionRequest,
        prefix: Option<&str>,
    ) -> Result<CreateChatCompletionResponse, PromptError> {
        let prefix = prefix.unwrap_or("llm").to_string();
        let debug_fp = self.on_llm_debug(&prefix);

        if let Some(debug_fp) = debug_fp.as_ref() {
            if let Err(e) = Self::save_llm_user(debug_fp, &req).await {
                warn!("Failed to save request transcript: {}", e);
            }
        }

        trace!("Sending completion request: {:?}", &req);
        let resp = self.client.create_chat(req).await?;

        if let Some(debug_fp) = debug_fp.as_ref() {
            if let Err(e) = Self::save_llm_resp(debug_fp, &resp).await {
                warn!("Failed to save response transcript: {}", e);
            }
        }

        if let Some(usage) = &resp.usage {
            self.billing
                .write()
                .await
                .input_tokens(&self.model, usage.prompt_tokens as u64)
                .map_err(PromptError::Other)?;
            self.billing
                .write()
                .await
                .output_tokens(&self.model, usage.completion_tokens as u64)
                .map_err(PromptError::Other)?;
        } else {
            warn!("No usage information in the response")
        }

        info!("Model Billing: {}", &self.billing.read().await);
        Ok(resp)
    }

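    /// Single prompt (system + user message) without retry or timeout handling.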
    pub async fn prompt_once(
        &self,
        sys_msg: &str,
        user_msg: &str,
        prefix: Option<&str>,
        settings: Option<LLMSettings>,
    ) -> Result<CreateChatCompletionResponse, PromptError> {
        let settings = settings.unwrap_or(self.default_settings.clone());
        let sys = ChatCompletionRequestSystemMessageArgs::default()
            .content(sys_msg)
            .build()?;

        let user = ChatCompletionRequestUserMessageArgs::default()
            .content(user_msg)
            .build()?;
        let req = CreateChatCompletionRequestArgs::default()
            .messages(vec![sys.into(), user.into()])
            .model(self.model.to_string())
            .temperature(settings.llm_temperature)
            .presence_penalty(settings.llm_presence_penalty)
            .max_completion_tokens(settings.llm_max_completion_tokens)
            .build()?;
        self.complete(req, prefix).await
    }
}