// openai_models/llm.rs

1use std::{
2    fmt::{Debug, Display},
3    ops::{Deref, DerefMut},
4    path::{Path, PathBuf},
5    str::FromStr,
6    sync::{
7        Arc,
8        atomic::{AtomicU64, Ordering},
9    },
10    time::Duration,
11};
12
13use async_openai::{
14    Client,
15    config::{AzureConfig, OpenAIConfig},
16    error::OpenAIError,
17    types::chat::{
18        ChatCompletionMessageToolCalls, ChatCompletionNamedToolChoiceCustom,
19        ChatCompletionRequestAssistantMessageContent,
20        ChatCompletionRequestAssistantMessageContentPart,
21        ChatCompletionRequestDeveloperMessageContent,
22        ChatCompletionRequestDeveloperMessageContentPart, ChatCompletionRequestMessage,
23        ChatCompletionRequestSystemMessageArgs, ChatCompletionRequestSystemMessageContent,
24        ChatCompletionRequestSystemMessageContentPart, ChatCompletionRequestToolMessageContent,
25        ChatCompletionRequestToolMessageContentPart, ChatCompletionRequestUserMessageArgs,
26        ChatCompletionRequestUserMessageContent, ChatCompletionRequestUserMessageContentPart,
27        ChatCompletionResponseMessage, ChatCompletionToolChoiceOption, ChatCompletionTools,
28        CreateChatCompletionRequest, CreateChatCompletionRequestArgs, CreateChatCompletionResponse,
29        CustomName, ToolChoiceOptions,
30    },
31};
32use clap::Args;
33use color_eyre::{
34    Result,
35    eyre::{OptionExt, eyre},
36};
37use itertools::Itertools;
38use log::{debug, info, trace, warn};
39use serde::{Deserialize, Serialize};
40use tokio::{io::AsyncWriteExt, sync::RwLock};
41
42use crate::{OpenAIModel, error::PromptError};
43
// Upstream implementation is flawed, hence this newtype with our own
// string parsing below.
/// Wrapper around [`ChatCompletionToolChoiceOption`] that can be parsed
/// from a CLI flag or environment-variable string (see the [`FromStr`] impl).
#[derive(Debug, Clone)]
pub struct LLMToolChoice(pub ChatCompletionToolChoiceOption);
47
48impl FromStr for LLMToolChoice {
49    type Err = PromptError;
50    fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
51        Ok(match s {
52            "auto" => Self(ChatCompletionToolChoiceOption::Mode(
53                ToolChoiceOptions::Auto,
54            )),
55            "required" => Self(ChatCompletionToolChoiceOption::Mode(
56                ToolChoiceOptions::Required,
57            )),
58            "none" => Self(ChatCompletionToolChoiceOption::Mode(
59                ToolChoiceOptions::None,
60            )),
61            _ => Self(ChatCompletionToolChoiceOption::Custom(
62                ChatCompletionNamedToolChoiceCustom {
63                    custom: CustomName {
64                        name: s.to_string(),
65                    },
66                },
67            )),
68        })
69    }
70}
71
// Transparent read access to the wrapped upstream option.
impl Deref for LLMToolChoice {
    type Target = ChatCompletionToolChoiceOption;
    fn deref(&self) -> &Self::Target {
        &self.0
    }
}
78
// Transparent mutable access to the wrapped upstream option.
impl DerefMut for LLMToolChoice {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.0
    }
}
84
// Allow wrapping an upstream option without spelling out the newtype.
impl From<ChatCompletionToolChoiceOption> for LLMToolChoice {
    fn from(value: ChatCompletionToolChoiceOption) -> Self {
        Self(value)
    }
}
90
// Unwrap back into the upstream type, e.g. for request builders.
impl From<LLMToolChoice> for ChatCompletionToolChoiceOption {
    fn from(value: LLMToolChoice) -> Self {
        value.0
    }
}
96
97#[derive(Args, Clone, Debug)]
98pub struct LLMSettings {
99    #[arg(long, env = "LLM_TEMPERATURE", default_value_t = 0.8)]
100    pub llm_temperature: f32,
101
102    #[arg(long, env = "LLM_PRESENCE_PENALTY", default_value_t = 0.0)]
103    pub llm_presence_penalty: f32,
104
105    #[arg(long, env = "LLM_PROMPT_TIMEOUT", default_value_t = 120)]
106    pub llm_prompt_timeout: u64,
107
108    #[arg(long, env = "LLM_RETRY", default_value_t = 5)]
109    pub llm_retry: u64,
110
111    #[arg(long, env = "LLM_MAX_COMPLETION_TOKENS", default_value_t = 16384)]
112    pub llm_max_completion_tokens: u32,
113
114    #[arg(long, env = "LLM_TOOL_CHOINCE")]
115    pub llm_tool_choice: Option<LLMToolChoice>,
116}
117
/// CLI/environment configuration for talking to OpenAI or Azure OpenAI.
#[derive(Args, Clone, Debug)]
pub struct OpenAISetup {
    /// Base URL of the OpenAI-compatible API.
    #[arg(
        long,
        env = "OPENAI_API_URL",
        default_value = "https://api.openai.com/v1"
    )]
    pub openai_url: String,

    /// When set, the Azure backend is used instead of `openai_url`
    /// (see `OpenAISetup::to_config`).
    #[arg(long, env = "AZURE_OPENAI_ENDPOINT")]
    pub azure_openai_endpoint: Option<String>,

    /// API key for either provider; an empty key is sent when absent.
    #[arg(long, env = "OPENAI_API_KEY")]
    pub openai_key: Option<String>,

    /// Azure deployment id; falls back to the model name when absent.
    #[arg(long, env = "AZURE_API_DEPLOYMENT")]
    pub azure_deployment: Option<String>,

    /// Azure REST API version string.
    #[arg(long, env = "AZURE_API_VERSION", default_value = "2025-01-01-preview")]
    pub azure_api_version: String,

    // NOTE(review): field name is a typo ("biling" -> "billing"); renaming it
    // would break callers and change the generated `--biling-cap` flag, so it
    // is documented rather than fixed here.
    /// Hard spending cap (in dollars) enforced by `ModelBilling`.
    #[arg(long, default_value_t = 10.0, env = "OPENAI_BILLING_CAP")]
    pub biling_cap: f64,

    /// Model to request completions from.
    #[arg(long, env = "OPENAI_API_MODEL", default_value = "o1")]
    pub model: OpenAIModel,

    /// Folder in which to record every LLM request/response for debugging;
    /// disabled when unset.
    #[arg(long, env = "LLM_DEBUG")]
    pub llm_debug: Option<PathBuf>,

    /// Default per-request LLM parameters (flattened into this command line).
    #[clap(flatten)]
    pub llm_settings: LLMSettings,
}
151
152impl OpenAISetup {
153    pub fn to_config(&self) -> SupportedConfig {
154        if let Some(ep) = self.azure_openai_endpoint.as_ref() {
155            let cfg = AzureConfig::new()
156                .with_api_base(ep)
157                .with_api_key(self.openai_key.clone().unwrap_or_default())
158                .with_deployment_id(
159                    self.azure_deployment
160                        .as_ref()
161                        .unwrap_or(&self.model.to_string()),
162                )
163                .with_api_version(&self.azure_api_version);
164            SupportedConfig::Azure(cfg)
165        } else {
166            let cfg = OpenAIConfig::new()
167                .with_api_base(&self.openai_url)
168                .with_api_key(self.openai_key.clone().unwrap_or_default());
169            SupportedConfig::OpenAI(cfg)
170        }
171    }
172
173    pub fn to_llm(&self) -> LLM {
174        let billing = RwLock::new(ModelBilling::new(self.biling_cap));
175
176        let debug_path = if let Some(dbg) = self.llm_debug.as_ref() {
177            let pid = std::process::id();
178
179            let mut cnt = 0u64;
180            let debug_path;
181            loop {
182                let test_path = dbg.join(format!("{}-{}", pid, cnt));
183                if !test_path.exists() {
184                    std::fs::create_dir_all(&test_path).expect("Fail to create llm debug path?");
185                    debug_path = Some(test_path);
186                    debug!("The path to save LLM interactions is {:?}", &debug_path);
187                    break;
188                } else {
189                    cnt += 1;
190                }
191            }
192            debug_path
193        } else {
194            None
195        };
196
197        LLM {
198            llm: Arc::new(LLMInner {
199                client: LLMClient::new(self.to_config()),
200                model: self.model.clone(),
201                billing,
202                llm_debug: debug_path,
203                llm_debug_index: AtomicU64::new(0),
204                default_settings: self.llm_settings.clone(),
205            }),
206        }
207    }
208}
209
/// Provider configuration selected by [`OpenAISetup::to_config`].
#[derive(Debug, Clone)]
pub enum SupportedConfig {
    Azure(AzureConfig),
    OpenAI(OpenAIConfig),
}
215
/// Concrete API client, one variant per supported provider.
#[derive(Debug, Clone)]
pub enum LLMClient {
    Azure(Client<AzureConfig>),
    OpenAI(Client<OpenAIConfig>),
}
221
222impl LLMClient {
223    pub fn new(config: SupportedConfig) -> Self {
224        match config {
225            SupportedConfig::Azure(cfg) => Self::Azure(Client::with_config(cfg)),
226            SupportedConfig::OpenAI(cfg) => Self::OpenAI(Client::with_config(cfg)),
227        }
228    }
229
230    pub async fn create_chat(
231        &self,
232        req: CreateChatCompletionRequest,
233    ) -> Result<CreateChatCompletionResponse, OpenAIError> {
234        match self {
235            Self::Azure(cl) => cl.chat().create(req).await,
236            Self::OpenAI(cl) => cl.chat().create(req).await,
237        }
238    }
239}
240
/// Running spend tracker: accumulated cost (`current`) against a hard
/// limit (`cap`), both in dollars.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelBilling {
    // Dollars spent so far.
    pub current: f64,
    // Hard limit; exceeding it makes the token-billing methods fail.
    pub cap: f64,
}
246
247impl Display for ModelBilling {
248    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
249        f.write_fmt(format_args!("Billing({}/{})", self.current, self.cap))
250    }
251}
252
253impl ModelBilling {
254    pub fn new(cap: f64) -> Self {
255        Self { current: 0.0, cap }
256    }
257
258    pub fn in_cap(&self) -> bool {
259        self.current <= self.cap
260    }
261
262    pub fn input_tokens(&mut self, model: &OpenAIModel, count: u64) -> Result<()> {
263        let pricing = model.pricing();
264
265        self.current += (pricing.input_tokens * (count as f64)) / 1e6;
266
267        if self.in_cap() {
268            Ok(())
269        } else {
270            Err(eyre!("cap {} reached, current {}", self.cap, self.current))
271        }
272    }
273
274    pub fn output_tokens(&mut self, model: &OpenAIModel, count: u64) -> Result<()> {
275        let pricing = model.pricing();
276
277        self.current += pricing.output_tokens * (count as f64) / 1e6;
278
279        if self.in_cap() {
280            Ok(())
281        } else {
282            Err(eyre!("cap {} reached, current {}", self.cap, self.current))
283        }
284    }
285}
286
/// Cheaply clonable handle to the shared LLM state ([`LLMInner`]).
#[derive(Debug, Clone)]
pub struct LLM {
    pub llm: Arc<LLMInner>,
}
291
// Let callers invoke LLMInner methods directly on an LLM handle.
impl Deref for LLM {
    type Target = LLMInner;

    fn deref(&self) -> &Self::Target {
        &self.llm
    }
}
299
/// Shared state behind an [`LLM`] handle: the provider client, billing
/// tracker, optional debug-transcript folder and default request settings.
#[derive(Debug)]
pub struct LLMInner {
    // Provider-specific API client.
    pub client: LLMClient,
    // Model every request is issued against.
    pub model: OpenAIModel,
    // Spend tracker, updated after each completed request.
    pub billing: RwLock<ModelBilling>,
    // Folder for request/response transcripts; None disables debug output.
    pub llm_debug: Option<PathBuf>,
    // Monotonic counter used to name transcript files.
    pub llm_debug_index: AtomicU64,
    // Settings applied when the caller passes no explicit LLMSettings.
    pub default_settings: LLMSettings,
}
309
310pub fn completion_to_role(msg: &ChatCompletionRequestMessage) -> &'static str {
311    match msg {
312        ChatCompletionRequestMessage::Assistant(_) => "ASSISTANT",
313        ChatCompletionRequestMessage::Developer(_) => "DEVELOPER",
314        ChatCompletionRequestMessage::Function(_) => "FUNCTION",
315        ChatCompletionRequestMessage::System(_) => "SYSTEM",
316        ChatCompletionRequestMessage::Tool(_) => "TOOL",
317        ChatCompletionRequestMessage::User(_) => "USER",
318    }
319}
320
321pub fn toolcall_to_string(t: &ChatCompletionMessageToolCalls) -> String {
322    match t {
323        ChatCompletionMessageToolCalls::Function(t) => {
324            format!(
325                "<toolcall name=\"{}\">\n{}\n</toolcall>",
326                &t.function.name, &t.function.arguments
327            )
328        }
329        ChatCompletionMessageToolCalls::Custom(t) => {
330            format!(
331                "<customtoolcall name=\"{}\">\n{}\n</customtoolcall>",
332                &t.custom_tool.name, &t.custom_tool.input
333            )
334        }
335    }
336}
337
338pub fn response_to_string(resp: &ChatCompletionResponseMessage) -> String {
339    let mut s = String::new();
340    if let Some(content) = resp.content.as_ref() {
341        s += content;
342        s += "\n";
343    }
344
345    if let Some(tools) = resp.tool_calls.as_ref() {
346        s += &tools.iter().map(|t| toolcall_to_string(t)).join("\n");
347    }
348
349    if let Some(refusal) = &resp.refusal {
350        s += refusal;
351        s += "\n";
352    }
353
354    let role = resp.role.to_string().to_uppercase();
355
356    format!("<{}>\n{}\n</{}>\n", &role, s, &role)
357}
358
/// Render a request-side chat message as a `<ROLE>…</ROLE>` block for the
/// debug transcript.
///
/// Array-style contents are joined with a `<cont/>` separator; an absent
/// content renders as `<none/>`. Non-text user parts (images, audio, files)
/// are summarised rather than embedded verbatim.
pub fn completion_to_string(msg: &ChatCompletionRequestMessage) -> String {
    // Separator between the parts of an array-style message content.
    const CONT: &str = "<cont/>\n";
    // Placeholder for a message with no content at all.
    const NONE: &str = "<none/>\n";
    let role = completion_to_role(msg);
    let content = match msg {
        // Assistant: optional content, then any tool calls appended after it.
        ChatCompletionRequestMessage::Assistant(ass) => {
            let msg = ass
                .content
                .as_ref()
                .map(|ass| match ass {
                    ChatCompletionRequestAssistantMessageContent::Text(s) => s.clone(),
                    ChatCompletionRequestAssistantMessageContent::Array(arr) => arr
                        .iter()
                        .map(|v| match v {
                            ChatCompletionRequestAssistantMessageContentPart::Text(s) => {
                                s.text.clone()
                            }
                            ChatCompletionRequestAssistantMessageContentPart::Refusal(rf) => {
                                rf.refusal.clone()
                            }
                        })
                        .join(CONT),
                })
                .unwrap_or(NONE.to_string());
            // `tool_calls` is Option<Vec<_>>; iter().flatten() walks
            // zero-or-more calls without unwrapping.
            let tool_calls = ass
                .tool_calls
                .iter()
                .flatten()
                .map(|t| toolcall_to_string(t))
                .join("\n");
            format!("{}\n{}", msg, tool_calls)
        }
        ChatCompletionRequestMessage::Developer(dev) => match &dev.content {
            ChatCompletionRequestDeveloperMessageContent::Text(t) => t.clone(),
            ChatCompletionRequestDeveloperMessageContent::Array(arr) => arr
                .iter()
                .map(|v| match v {
                    ChatCompletionRequestDeveloperMessageContentPart::Text(v) => v.text.clone(),
                })
                .join(CONT),
        },
        ChatCompletionRequestMessage::Function(f) => f.content.clone().unwrap_or(NONE.to_string()),
        ChatCompletionRequestMessage::System(sys) => match &sys.content {
            ChatCompletionRequestSystemMessageContent::Text(t) => t.clone(),
            ChatCompletionRequestSystemMessageContent::Array(arr) => arr
                .iter()
                .map(|v| match v {
                    ChatCompletionRequestSystemMessageContentPart::Text(t) => t.text.clone(),
                })
                .join(CONT),
        },
        ChatCompletionRequestMessage::Tool(tool) => match &tool.content {
            ChatCompletionRequestToolMessageContent::Text(t) => t.clone(),
            ChatCompletionRequestToolMessageContent::Array(arr) => arr
                .iter()
                .map(|v| match v {
                    ChatCompletionRequestToolMessageContentPart::Text(t) => t.text.clone(),
                })
                .join(CONT),
        },
        // User content may mix text with media; media parts are rendered as
        // short descriptive tags (the audio data is base64, not transcribed).
        ChatCompletionRequestMessage::User(usr) => match &usr.content {
            ChatCompletionRequestUserMessageContent::Text(t) => t.clone(),
            ChatCompletionRequestUserMessageContent::Array(arr) => arr
                .iter()
                .map(|v| match v {
                    ChatCompletionRequestUserMessageContentPart::Text(t) => t.text.clone(),
                    ChatCompletionRequestUserMessageContentPart::ImageUrl(img) => {
                        format!("<img url=\"{}\"/>", &img.image_url.url)
                    }
                    ChatCompletionRequestUserMessageContentPart::InputAudio(audio) => {
                        format!("<audio>{}</audio>", audio.input_audio.data)
                    }
                    ChatCompletionRequestUserMessageContentPart::File(f) => {
                        format!("<file>{:?}</file>", f)
                    }
                })
                .join(CONT),
        },
    };

    format!("<{}>\n{}\n</{}>\n", role, content, role)
}
441
impl LLMInner {
    /// Append `t` as one JSON line to a `.json` file sitting next to `fpath`
    /// (same stem, `.json` extension). Despite the name, this appends rather
    /// than rewrites, so the request and response of one exchange accumulate
    /// in the same file. Falls back to the `Debug` representation when serde
    /// serialization fails.
    async fn rewrite_json<T: Serialize + Debug>(fpath: &Path, t: &T) -> Result<(), PromptError> {
        let mut json_fp = fpath.to_path_buf();
        json_fp.set_file_name(format!(
            "{}.json",
            json_fp
                .file_stem()
                .ok_or_eyre(eyre!("no filename"))?
                .to_str()
                .ok_or_eyre(eyre!("non-utf fname"))?
        ));

        let mut fp = tokio::fs::OpenOptions::new()
            .create(true)
            .append(true)
            .write(true)
            .open(&json_fp)
            .await?;
        let s = match serde_json::to_string(&t) {
            Ok(s) => s,
            // Best effort: never lose the record just because it failed to
            // serialize cleanly.
            Err(_) => format!("{:?}", &t),
        };
        fp.write_all(s.as_bytes()).await?;
        fp.write_all(b"\n").await?;
        fp.flush().await?;

        Ok(())
    }

    /// Write the outgoing request (all messages plus declared tools) to the
    /// transcript at `fpath`, truncating any previous content, and mirror the
    /// raw request into the sibling `.json` file via `rewrite_json`.
    async fn save_llm_user(
        fpath: &PathBuf,
        user_msg: &CreateChatCompletionRequest,
    ) -> Result<(), PromptError> {
        let mut fp = tokio::fs::OpenOptions::new()
            .create(true)
            .truncate(true)
            .write(true)
            .open(&fpath)
            .await?;
        fp.write_all(b"=====================\n<Request>\n").await?;
        for it in user_msg.messages.iter() {
            let msg = completion_to_string(it);
            fp.write_all(msg.as_bytes()).await?;
        }

        // `tools` is Option<Vec<_>>; the map/into_iter/flatten chain iterates
        // zero-or-more declared tools without unwrapping.
        let mut tools = vec![];
        for tool in user_msg
            .tools
            .as_ref()
            .map(|t| t.iter())
            .into_iter()
            .flatten()
        {
            let s = match tool {
                ChatCompletionTools::Function(tool) => {
                    format!(
                        "<tool name=\"{}\", description=\"{}\", strict={}>\n{}\n</tool>",
                        &tool.function.name,
                        &tool.function.description.clone().unwrap_or_default(),
                        tool.function.strict.unwrap_or_default(),
                        // Pretty-print the JSON-schema parameters when present.
                        tool.function
                            .parameters
                            .as_ref()
                            .map(serde_json::to_string_pretty)
                            .transpose()?
                            .unwrap_or_default()
                    )
                }
                ChatCompletionTools::Custom(tool) => {
                    format!(
                        "<customtool name=\"{}\", description=\"{:?}\"></customtool>",
                        tool.custom.name, tool.custom.description
                    )
                }
            };
            tools.push(s);
        }
        fp.write_all(tools.join("\n").as_bytes()).await?;
        fp.write_all(b"\n</Request>\n=====================\n")
            .await?;
        fp.flush().await?;

        Self::rewrite_json(fpath, user_msg).await?;

        Ok(())
    }

    /// Append the response choices to the transcript at `fpath` and mirror
    /// the raw response into the sibling `.json` file.
    ///
    /// `create(false)`: the file must already exist, i.e. `save_llm_user`
    /// is expected to have run first for this exchange.
    async fn save_llm_resp(fpath: &PathBuf, resp: &CreateChatCompletionResponse) -> Result<()> {
        let mut fp = tokio::fs::OpenOptions::new()
            .create(false)
            .append(true)
            .write(true)
            .open(&fpath)
            .await?;
        fp.write_all(b"=====================\n<Response>\n").await?;
        for it in &resp.choices {
            let msg = response_to_string(&it.message);
            fp.write_all(msg.as_bytes()).await?;
        }
        fp.write_all(b"\n</Response>\n=====================\n")
            .await?;
        fp.flush().await?;

        Self::rewrite_json(fpath, resp).await?;

        Ok(())
    }

    /// When debug recording is enabled, reserve the next transcript file path
    /// (`<prefix>-<zero-padded index>.xml`); returns None when disabled.
    fn on_llm_debug(&self, prefix: &str) -> Option<PathBuf> {
        if let Some(output_folder) = self.llm_debug.as_ref() {
            // SeqCst fetch_add gives each exchange a unique, ordered index.
            let idx = self.llm_debug_index.fetch_add(1, Ordering::SeqCst);
            let fpath = output_folder.join(format!("{}-{:0>12}.xml", prefix, idx));
            Some(fpath)
        } else {
            None
        }
    }

    // we use t/s to estimate a timeout to avoid infinite repeating
    /// Build a system+user request from the given strings and the settings
    /// (falling back to `default_settings`), then run it through
    /// `complete_once_with_retry` with the settings' timeout and retry count.
    /// A `llm_prompt_timeout` of 0 means no timeout.
    pub async fn prompt_once_with_retry(
        &self,
        sys_msg: &str,
        user_msg: &str,
        prefix: Option<&str>,
        settings: Option<LLMSettings>,
    ) -> Result<CreateChatCompletionResponse, PromptError> {
        let settings = settings.unwrap_or_else(|| self.default_settings.clone());
        let sys = ChatCompletionRequestSystemMessageArgs::default()
            .content(sys_msg)
            .build()?;

        let user = ChatCompletionRequestUserMessageArgs::default()
            .content(user_msg)
            .build()?;
        let mut req = CreateChatCompletionRequestArgs::default();
        req.messages(vec![sys.into(), user.into()])
            .model(self.model.to_string())
            .temperature(settings.llm_temperature)
            .presence_penalty(settings.llm_presence_penalty)
            .max_completion_tokens(settings.llm_max_completion_tokens);

        // Tool choice is only set when explicitly configured.
        if let Some(tc) = settings.llm_tool_choice {
            req.tool_choice(tc);
        }
        let req = req.build()?;

        let timeout = if settings.llm_prompt_timeout == 0 {
            Duration::MAX
        } else {
            Duration::from_secs(settings.llm_prompt_timeout)
        };

        self.complete_once_with_retry(&req, prefix, Some(timeout), Some(settings.llm_retry))
            .await
    }

    /// Run `complete` up to `retry` times (default: unbounded), each attempt
    /// bounded by `timeout` (default: unbounded). Returns the first success,
    /// or the last API error once attempts are exhausted.
    pub async fn complete_once_with_retry(
        &self,
        req: &CreateChatCompletionRequest,
        prefix: Option<&str>,
        timeout: Option<Duration>,
        retry: Option<u64>,
    ) -> Result<CreateChatCompletionResponse, PromptError> {
        let timeout = if let Some(timeout) = timeout {
            timeout
        } else {
            Duration::MAX
        };

        let retry = if let Some(retry) = retry {
            retry
        } else {
            u64::MAX
        };

        // Holds the outcome of the most recent attempt that did NOT time out.
        let mut last = None;
        for idx in 0..retry {
            match tokio::time::timeout(timeout, self.complete(req.clone(), prefix)).await {
                Ok(r) => {
                    last = Some(r);
                }
                Err(_) => {
                    // Timed-out attempts leave `last` untouched; NOTE(review):
                    // if every attempt times out, `last` stays None and the
                    // "retry is zero?!" error below is what the caller sees.
                    warn!("Timeout with {} retry, timeout = {:?}", idx, timeout);
                    continue;
                }
            };

            match last {
                // Success: return immediately (this arm moves out of `last`,
                // which is fine because it also exits the function).
                Some(Ok(r)) => return Ok(r),
                Some(Err(ref e)) => {
                    warn!(
                        "Having an error {} during {} retry (timeout is {:?})",
                        e, idx, timeout
                    );
                }
                _ => {}
            }
        }

        // Either the last attempt's Err, or an error when no attempt ever
        // produced a result (retry == 0, or all attempts timed out).
        last.ok_or_eyre(eyre!("retry is zero?!"))
            .map_err(PromptError::Other)?
    }

    /// Issue one chat-completion call: optionally record the request and
    /// response to the debug transcript, then bill the reported token usage
    /// against the spending cap (failing the call when the cap is exceeded).
    pub async fn complete(
        &self,
        req: CreateChatCompletionRequest,
        prefix: Option<&str>,
    ) -> Result<CreateChatCompletionResponse, PromptError> {
        // Transcript file prefix; "llm" when the caller does not care.
        let prefix = if let Some(prefix) = prefix {
            prefix.to_string()
        } else {
            "llm".to_string()
        };
        let debug_fp = self.on_llm_debug(&prefix);

        // Debug recording is best effort: failures are logged, not fatal.
        if let Some(debug_fp) = debug_fp.as_ref() {
            if let Err(e) = Self::save_llm_user(debug_fp, &req).await {
                warn!("Fail to save user due to {}", e);
            }
        }

        trace!(
            "Sending completion request: {:?}",
            &serde_json::to_string(&req)
        );
        let resp = self.client.create_chat(req).await?;

        if let Some(debug_fp) = debug_fp.as_ref() {
            if let Err(e) = Self::save_llm_resp(debug_fp, &resp).await {
                warn!("Fail to save resp due to {}", e);
            }
        }

        // NOTE(review): input and output tokens are billed under two separate
        // write-lock acquisitions, so another task may observe/interleave
        // between them; totals still end up correct.
        if let Some(usage) = &resp.usage {
            self.billing
                .write()
                .await
                .input_tokens(&self.model, usage.prompt_tokens as u64)
                .map_err(PromptError::Other)?;
            self.billing
                .write()
                .await
                .output_tokens(&self.model, usage.completion_tokens as u64)
                .map_err(PromptError::Other)?;
        } else {
            warn!("No usage?!")
        }

        info!("Model Billing: {}", &self.billing.read().await);
        Ok(resp)
    }

    /// Single-shot variant of `prompt_once_with_retry`: builds the same
    /// system+user request and calls `complete` exactly once.
    ///
    /// NOTE(review): unlike `prompt_once_with_retry`, this does not apply
    /// `llm_tool_choice`, the prompt timeout, or retries from the settings —
    /// confirm whether that asymmetry is intentional.
    pub async fn prompt_once(
        &self,
        sys_msg: &str,
        user_msg: &str,
        prefix: Option<&str>,
        settings: Option<LLMSettings>,
    ) -> Result<CreateChatCompletionResponse, PromptError> {
        let settings = settings.unwrap_or_else(|| self.default_settings.clone());
        let sys = ChatCompletionRequestSystemMessageArgs::default()
            .content(sys_msg)
            .build()?;

        let user = ChatCompletionRequestUserMessageArgs::default()
            .content(user_msg)
            .build()?;
        let req = CreateChatCompletionRequestArgs::default()
            .messages(vec![sys.into(), user.into()])
            .model(self.model.to_string())
            .temperature(settings.llm_temperature)
            .presence_penalty(settings.llm_presence_penalty)
            .max_completion_tokens(settings.llm_max_completion_tokens)
            .build()?;
        self.complete(req, prefix).await
    }
}