use std::{
    fmt::{Debug, Display},
    ops::Deref,
    path::{Path, PathBuf},
    str::FromStr,
    sync::{
        Arc,
        atomic::{AtomicU64, Ordering},
    },
    time::Duration,
};

use async_openai::{
    Client,
    config::{AzureConfig, OpenAIConfig},
    error::OpenAIError,
    types::{
        ChatCompletionRequestAssistantMessageContent,
        ChatCompletionRequestAssistantMessageContentPart,
        ChatCompletionRequestDeveloperMessageContent, ChatCompletionRequestMessage,
        ChatCompletionRequestSystemMessageArgs, ChatCompletionRequestSystemMessageContent,
        ChatCompletionRequestSystemMessageContentPart, ChatCompletionRequestToolMessageContent,
        ChatCompletionRequestToolMessageContentPart, ChatCompletionRequestUserMessageArgs,
        ChatCompletionRequestUserMessageContent, ChatCompletionRequestUserMessageContentPart,
        ChatCompletionResponseMessage, ChatCompletionToolChoiceOption, CreateChatCompletionRequest,
        CreateChatCompletionRequestArgs, CreateChatCompletionResponse,
    },
};
use clap::Args;
use color_eyre::{
    Result,
    eyre::{OptionExt, eyre},
};
use itertools::Itertools;
use log::{debug, info, trace, warn};
use serde::{Deserialize, Serialize};
use tokio::{io::AsyncWriteExt, sync::RwLock};

use crate::{OpenAIModel, error::PromptError};

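/// Tool-choice policy that can be parsed from a CLI flag or environment variable:
/// "auto", "required", "none", or the name of a specific tool.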
#[derive(Debug, Clone)]
pub struct LLMToolChoice(pub ChatCompletionToolChoiceOption);

impl FromStr for LLMToolChoice {
    type Err = PromptError;
    fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
        Ok(match s {
            "auto" => Self(ChatCompletionToolChoiceOption::Auto),
            "required" => Self(ChatCompletionToolChoiceOption::Required),
            "none" => Self(ChatCompletionToolChoiceOption::None),
            _ => Self(ChatCompletionToolChoiceOption::Named(s.into())),
        })
    }
}

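/// Sampling, timeout, retry, and tool-choice settings applied to each prompt.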
#[derive(Args, Clone, Debug)]
pub struct LLMSettings {
    #[arg(long, env = "LLM_TEMPERATURE", default_value_t = 0.8)]
    pub llm_temperature: f32,

    #[arg(long, env = "LLM_PRESENCE_PENALTY", default_value_t = 0.0)]
    pub llm_presence_penalty: f32,

    #[arg(long, env = "LLM_PROMPT_TIMEOUT", default_value_t = 120)]
    pub llm_prompt_timeout: u64,

    #[arg(long, env = "LLM_RETRY", default_value_t = 5)]
    pub llm_retry: u64,

    #[arg(long, env = "LLM_MAX_COMPLETION_TOKENS", default_value_t = 16384)]
    pub llm_max_completion_tokens: u32,

    #[arg(long, env = "LLM_TOOL_CHOICE", default_value = "auto")]
    pub llm_tool_choice: LLMToolChoice,
}

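/// CLI/environment configuration for the OpenAI (or Azure OpenAI) backend, including the
/// billing cap and the optional LLM debug-dump directory.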
#[derive(Args, Clone, Debug)]
pub struct OpenAISetup {
    #[arg(
        long,
        env = "OPENAI_API_URL",
        default_value = "https://api.openai.com/v1"
    )]
    pub openai_url: String,

    #[arg(long, env = "OPENAI_API_KEY")]
    pub openai_key: Option<String>,

    #[arg(long, env = "OPENAI_API_ENDPOINT")]
    pub openai_endpoint: Option<String>,

    #[arg(long, default_value_t = 10.0, env = "OPENAI_BILLING_CAP")]
    pub billing_cap: f64,

    #[arg(long, env = "OPENAI_API_MODEL", default_value = "o1")]
    pub model: OpenAIModel,

    #[arg(long, env = "LLM_DEBUG")]
    pub llm_debug: Option<PathBuf>,

    #[clap(flatten)]
    pub llm_settings: LLMSettings,
}

impl OpenAISetup {
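    /// Choose the Azure configuration when an API endpoint (deployment id) is set,
    /// otherwise fall back to the plain OpenAI configuration.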
    pub fn to_config(&self) -> SupportedConfig {
        if let Some(ep) = self.openai_endpoint.as_ref() {
            let cfg = AzureConfig::new()
                .with_api_base(&self.openai_url)
                .with_api_key(self.openai_key.clone().unwrap_or_default())
                .with_deployment_id(ep);
            SupportedConfig::Azure(cfg)
        } else {
            let cfg = OpenAIConfig::new()
                .with_api_base(&self.openai_url)
                .with_api_key(self.openai_key.clone().unwrap_or_default());
            SupportedConfig::OpenAI(cfg)
        }
    }

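    /// Build the shared [`LLM`] handle: initialize billing with the configured cap and,
    /// when `llm_debug` is set, create a fresh directory for dumping LLM interactions.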
    pub fn to_llm(&self) -> LLM {
        let billing = RwLock::new(ModelBilling::new(self.billing_cap));

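        // Find the first <pid>-<counter> directory under the debug root that does not
        // exist yet and create it; all dumps for this process go there.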
        let debug_path = if let Some(dbg) = self.llm_debug.as_ref() {
            let pid = std::process::id();

            let mut cnt = 0u64;
            let debug_path;
            loop {
                let test_path = dbg.join(format!("{}-{}", pid, cnt));
                if !test_path.exists() {
                    std::fs::create_dir_all(&test_path)
                        .expect("Failed to create the LLM debug path");
                    debug_path = Some(test_path);
                    debug!("The path to save LLM interactions is {:?}", &debug_path);
                    break;
                } else {
                    cnt += 1;
                }
            }
            debug_path
        } else {
            None
        };

        LLM {
            llm: Arc::new(LLMInner {
                client: LLMClient::new(self.to_config()),
                model: self.model.clone(),
                billing,
                llm_debug: debug_path,
                llm_debug_index: AtomicU64::new(0),
                default_settings: self.llm_settings.clone(),
            }),
        }
    }
}

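/// Backend configuration: an Azure OpenAI deployment or the standard OpenAI API.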
#[derive(Debug, Clone)]
pub enum SupportedConfig {
    Azure(AzureConfig),
    OpenAI(OpenAIConfig),
}

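/// A chat client for one of the supported backend configurations.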
#[derive(Debug, Clone)]
pub enum LLMClient {
    Azure(Client<AzureConfig>),
    OpenAI(Client<OpenAIConfig>),
}

impl LLMClient {
    pub fn new(config: SupportedConfig) -> Self {
        match config {
            SupportedConfig::Azure(cfg) => Self::Azure(Client::with_config(cfg)),
            SupportedConfig::OpenAI(cfg) => Self::OpenAI(Client::with_config(cfg)),
        }
    }

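    /// Forward a chat completion request to whichever backend this client wraps.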
    pub async fn create_chat(
        &self,
        req: CreateChatCompletionRequest,
    ) -> Result<CreateChatCompletionResponse, OpenAIError> {
        match self {
            Self::Azure(cl) => cl.chat().create(req).await,
            Self::OpenAI(cl) => cl.chat().create(req).await,
        }
    }
}

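/// Accumulated spend so far, together with the cap it must stay under.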
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelBilling {
    pub current: f64,
    pub cap: f64,
}

impl Display for ModelBilling {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_fmt(format_args!("Billing({}/{})", self.current, self.cap))
    }
}

impl ModelBilling {
    pub fn new(cap: f64) -> Self {
        Self { current: 0.0, cap }
    }

    pub fn in_cap(&self) -> bool {
        self.current <= self.cap
    }

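    /// Charge `count` prompt tokens at the model's input price (per 1e6 tokens);
    /// errors once the cap is exceeded.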
    pub fn input_tokens(&mut self, model: &OpenAIModel, count: u64) -> Result<()> {
        let pricing = model.pricing();

        self.current += (pricing.input_tokens * (count as f64)) / 1e6;

        if self.in_cap() {
            Ok(())
        } else {
            Err(eyre!("cap {} reached, current {}", self.cap, self.current))
        }
    }

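    /// Charge `count` completion tokens at the model's output price (per 1e6 tokens);
    /// errors once the cap is exceeded.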
    pub fn output_tokens(&mut self, model: &OpenAIModel, count: u64) -> Result<()> {
        let pricing = model.pricing();

        self.current += pricing.output_tokens * (count as f64) / 1e6;

        if self.in_cap() {
            Ok(())
        } else {
            Err(eyre!("cap {} reached, current {}", self.cap, self.current))
        }
    }
}

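/// Cheaply clonable handle to the shared [`LLMInner`] state.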
#[derive(Debug, Clone)]
pub struct LLM {
    pub llm: Arc<LLMInner>,
}

impl Deref for LLM {
    type Target = LLMInner;

    fn deref(&self) -> &Self::Target {
        &self.llm
    }
}

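/// Shared state behind [`LLM`]: the backend client, model, billing tracker, default
/// settings, and the optional debug-dump directory with its file counter.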
#[derive(Debug)]
pub struct LLMInner {
    pub client: LLMClient,
    pub model: OpenAIModel,
    pub billing: RwLock<ModelBilling>,
    pub llm_debug: Option<PathBuf>,
    pub llm_debug_index: AtomicU64,
    pub default_settings: LLMSettings,
}

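/// Role tag used in the debug dumps for a request message.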
pub fn completion_to_role(msg: &ChatCompletionRequestMessage) -> &'static str {
    match msg {
        ChatCompletionRequestMessage::Assistant(_) => "ASSISTANT",
        ChatCompletionRequestMessage::Developer(_) => "DEVELOPER",
        ChatCompletionRequestMessage::Function(_) => "FUNCTION",
        ChatCompletionRequestMessage::System(_) => "SYSTEM",
        ChatCompletionRequestMessage::Tool(_) => "TOOL",
        ChatCompletionRequestMessage::User(_) => "USER",
    }
}

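/// Render a response message (content, tool calls, refusal) as an XML-like block for the
/// debug dump.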
pub fn response_to_string(resp: &ChatCompletionResponseMessage) -> String {
    let mut s = String::new();
    if let Some(content) = resp.content.as_ref() {
        s += content;
        s += "\n";
    }

    if let Some(tools) = resp.tool_calls.as_ref() {
        s += &tools
            .iter()
            .map(|t| {
                format!(
                    "<toolcall name=\"{}\">\n{}\n</toolcall>",
                    &t.function.name, &t.function.arguments
                )
            })
            .join("\n");
    }

    if let Some(refusal) = &resp.refusal {
        s += refusal;
        s += "\n";
    }

    let role = resp.role.to_string().to_uppercase();

    format!("<{}>\n{}\n</{}>\n", &role, s, &role)
}

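/// Render a request message as an XML-like block for the debug dump; multi-part content
/// is joined with `<cont/>` and missing content becomes `<none/>`.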
pub fn completion_to_string(msg: &ChatCompletionRequestMessage) -> String {
    const CONT: &str = "<cont/>\n";
    const NONE: &str = "<none/>\n";
    let role = completion_to_role(msg);
    let content = match msg {
        ChatCompletionRequestMessage::Assistant(ass) => {
            let text = ass
                .content
                .as_ref()
                .map(|content| match content {
                    ChatCompletionRequestAssistantMessageContent::Text(s) => s.clone(),
                    ChatCompletionRequestAssistantMessageContent::Array(arr) => arr
                        .iter()
                        .map(|v| match v {
                            ChatCompletionRequestAssistantMessageContentPart::Text(s) => {
                                s.text.clone()
                            }
                            ChatCompletionRequestAssistantMessageContentPart::Refusal(rf) => {
                                rf.refusal.clone()
                            }
                        })
                        .join(CONT),
                })
                .unwrap_or(NONE.to_string());
            let tool_calls = ass
                .tool_calls
                .iter()
                .flatten()
                .map(|t| {
                    format!(
                        "<toolcall name=\"{}\">{}</toolcall>",
                        &t.function.name, &t.function.arguments
                    )
                })
                .join("\n");
            format!("{}\n{}", text, tool_calls)
        }
        ChatCompletionRequestMessage::Developer(dev) => match &dev.content {
            ChatCompletionRequestDeveloperMessageContent::Text(t) => t.clone(),
            ChatCompletionRequestDeveloperMessageContent::Array(arr) => {
                arr.iter().map(|v| v.text.clone()).join(CONT)
            }
        },
        ChatCompletionRequestMessage::Function(f) => f.content.clone().unwrap_or(NONE.to_string()),
        ChatCompletionRequestMessage::System(sys) => match &sys.content {
            ChatCompletionRequestSystemMessageContent::Text(t) => t.clone(),
            ChatCompletionRequestSystemMessageContent::Array(arr) => arr
                .iter()
                .map(|v| match v {
                    ChatCompletionRequestSystemMessageContentPart::Text(t) => t.text.clone(),
                })
                .join(CONT),
        },
        ChatCompletionRequestMessage::Tool(tool) => match &tool.content {
            ChatCompletionRequestToolMessageContent::Text(t) => t.clone(),
            ChatCompletionRequestToolMessageContent::Array(arr) => arr
                .iter()
                .map(|v| match v {
                    ChatCompletionRequestToolMessageContentPart::Text(t) => t.text.clone(),
                })
                .join(CONT),
        },
        ChatCompletionRequestMessage::User(usr) => match &usr.content {
            ChatCompletionRequestUserMessageContent::Text(t) => t.clone(),
            ChatCompletionRequestUserMessageContent::Array(arr) => arr
                .iter()
                .map(|v| match v {
                    ChatCompletionRequestUserMessageContentPart::Text(t) => t.text.clone(),
                    ChatCompletionRequestUserMessageContentPart::ImageUrl(img) => {
                        format!("<img url=\"{}\"/>", &img.image_url.url)
                    }
                    ChatCompletionRequestUserMessageContentPart::InputAudio(audio) => {
                        format!("<audio>{}</audio>", audio.input_audio.data)
                    }
                })
                .join(CONT),
        },
    };

    format!("<{}>\n{}\n</{}>\n", role, content, role)
}

impl LLMInner {
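    /// Append `t` as one JSON line to a sibling file with the same stem and a `.json`
    /// extension (falling back to the `Debug` representation if serialization fails).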
    async fn rewrite_json<T: Serialize + Debug>(fpath: &Path, t: &T) -> Result<(), PromptError> {
        let mut json_fp = fpath.to_path_buf();
        json_fp.set_file_name(format!(
            "{}.json",
            json_fp
                .file_stem()
                .ok_or_eyre(eyre!("no filename"))?
                .to_str()
                .ok_or_eyre(eyre!("non-utf fname"))?
        ));

        let mut fp = tokio::fs::OpenOptions::new()
            .create(true)
            .append(true)
            .write(true)
            .open(&json_fp)
            .await?;
        let s = match serde_json::to_string(&t) {
            Ok(s) => s,
            Err(_) => format!("{:?}", &t),
        };
        fp.write_all(s.as_bytes()).await?;
        fp.write_all(b"\n").await?;
        fp.flush().await?;

        Ok(())
    }

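    /// Write the outgoing request (messages and tool definitions) to the debug file and
    /// mirror it into the `.json` sibling.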
    async fn save_llm_user(
        fpath: &PathBuf,
        user_msg: &CreateChatCompletionRequest,
    ) -> Result<(), PromptError> {
        let mut fp = tokio::fs::OpenOptions::new()
            .create(true)
            .truncate(true)
            .write(true)
            .open(&fpath)
            .await?;
        fp.write_all(b"<Request>\n").await?;
        for it in user_msg.messages.iter() {
            let msg = completion_to_string(it);
            fp.write_all(msg.as_bytes()).await?;
        }

        let mut tools = vec![];
        for tool in user_msg
            .tools
            .as_ref()
            .map(|t: &Vec<async_openai::types::ChatCompletionTool>| t.iter())
            .into_iter()
            .flatten()
        {
            tools.push(format!(
                "<tool name=\"{}\" description=\"{}\" strict={}>\n{}\n</tool>",
                &tool.function.name,
                &tool.function.description.clone().unwrap_or_default(),
                tool.function.strict.unwrap_or_default(),
                tool.function
                    .parameters
                    .as_ref()
                    .map(serde_json::to_string_pretty)
                    .transpose()?
                    .unwrap_or_default()
            ));
        }
        fp.write_all(tools.join("\n").as_bytes()).await?;
        fp.write_all(b"\n</Request>\n").await?;
        fp.flush().await?;

        Self::rewrite_json(fpath, user_msg).await?;

        Ok(())
    }

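    /// Append the response choices to the debug file and mirror the response into the
    /// `.json` sibling.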
    async fn save_llm_resp(fpath: &PathBuf, resp: &CreateChatCompletionResponse) -> Result<()> {
        let mut fp = tokio::fs::OpenOptions::new()
            .create(false)
            .append(true)
            .write(true)
            .open(&fpath)
            .await?;
        fp.write_all(b"<Response>\n").await?;
        for it in &resp.choices {
            let msg = response_to_string(&it.message);
            fp.write_all(msg.as_bytes()).await?;
        }
        fp.write_all(b"\n</Response>\n").await?;
        fp.flush().await?;

        Self::rewrite_json(fpath, resp).await?;

        Ok(())
    }

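    /// When debugging is enabled, reserve the next `<prefix>-<index>.xml` path in the
    /// debug directory.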
    fn on_llm_debug(&self, prefix: &str) -> Option<PathBuf> {
        if let Some(output_folder) = self.llm_debug.as_ref() {
            let idx = self.llm_debug_index.fetch_add(1, Ordering::SeqCst);
            let fpath = output_folder.join(format!("{}-{:0>12}.xml", prefix, idx));
            Some(fpath)
        } else {
            None
        }
    }

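    /// Build a system + user chat request from `settings` and send it through
    /// [`Self::complete_once_with_retry`], honoring the configured timeout and retry
    /// count (a timeout of 0 means no limit).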
    pub async fn prompt_once_with_retry(
        &self,
        sys_msg: &str,
        user_msg: &str,
        prefix: Option<&str>,
        settings: Option<LLMSettings>,
    ) -> Result<CreateChatCompletionResponse, PromptError> {
        let settings = settings.unwrap_or(self.default_settings.clone());
        let sys = ChatCompletionRequestSystemMessageArgs::default()
            .content(sys_msg)
            .build()?;

        let user = ChatCompletionRequestUserMessageArgs::default()
            .content(user_msg)
            .build()?;
        let req = CreateChatCompletionRequestArgs::default()
            .messages(vec![sys.into(), user.into()])
            .model(self.model.to_string())
            .temperature(settings.llm_temperature)
            .presence_penalty(settings.llm_presence_penalty)
            .max_completion_tokens(settings.llm_max_completion_tokens)
            .tool_choice(settings.llm_tool_choice.0)
            .build()?;

        let timeout = if settings.llm_prompt_timeout == 0 {
            Duration::MAX
        } else {
            Duration::from_secs(settings.llm_prompt_timeout)
        };

        self.complete_once_with_retry(&req, prefix, Some(timeout), Some(settings.llm_retry))
            .await
    }

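    /// Call [`Self::complete`] up to `retry` times, bounding each attempt by `timeout`;
    /// `None` for either means effectively unlimited.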
    pub async fn complete_once_with_retry(
        &self,
        req: &CreateChatCompletionRequest,
        prefix: Option<&str>,
        timeout: Option<Duration>,
        retry: Option<u64>,
    ) -> Result<CreateChatCompletionResponse, PromptError> {
        let timeout = if let Some(timeout) = timeout {
            timeout
        } else {
            Duration::MAX
        };

        let retry = if let Some(retry) = retry {
            retry
        } else {
            u64::MAX
        };

        let mut last = None;
        for idx in 0..retry {
            match tokio::time::timeout(timeout, self.complete(req.clone(), prefix)).await {
                Ok(r) => {
                    last = Some(r);
                }
                Err(_) => {
                    warn!("Timed out on attempt {} (timeout = {:?})", idx, timeout);
                    continue;
                }
            };

            match last {
                Some(Ok(r)) => return Ok(r),
                Some(Err(ref e)) => {
                    warn!("Error {} on attempt {} (timeout = {:?})", e, idx, timeout);
                }
                _ => {}
            }
        }

        last.ok_or_eyre(eyre!("retry count is zero?!"))
            .map_err(PromptError::Other)?
    }

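    /// Perform a single chat completion: dump the request and response when debugging is
    /// enabled and charge the reported token usage against the billing cap.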
    pub async fn complete(
        &self,
        req: CreateChatCompletionRequest,
        prefix: Option<&str>,
    ) -> Result<CreateChatCompletionResponse, PromptError> {
        let prefix = if let Some(prefix) = prefix {
            prefix.to_string()
        } else {
            "llm".to_string()
        };
        let debug_fp = self.on_llm_debug(&prefix);

        if let Some(debug_fp) = debug_fp.as_ref() {
            if let Err(e) = Self::save_llm_user(debug_fp, &req).await {
                warn!("Failed to save the request due to {}", e);
            }
        }

        trace!("Sending completion request: {:?}", &req);
        let resp = self.client.create_chat(req).await?;

        if let Some(debug_fp) = debug_fp.as_ref() {
            if let Err(e) = Self::save_llm_resp(debug_fp, &resp).await {
                warn!("Failed to save the response due to {}", e);
            }
        }

        if let Some(usage) = &resp.usage {
            self.billing
                .write()
                .await
                .input_tokens(&self.model, usage.prompt_tokens as u64)
                .map_err(PromptError::Other)?;
            self.billing
                .write()
                .await
                .output_tokens(&self.model, usage.completion_tokens as u64)
                .map_err(PromptError::Other)?;
        } else {
            warn!("No usage information in the response")
        }

        info!("Model Billing: {}", &self.billing.read().await);
        Ok(resp)
    }

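    /// Build a system + user chat request from `settings` and send it once, without
    /// retry, timeout, or tool-choice handling.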
    pub async fn prompt_once(
        &self,
        sys_msg: &str,
        user_msg: &str,
        prefix: Option<&str>,
        settings: Option<LLMSettings>,
    ) -> Result<CreateChatCompletionResponse, PromptError> {
        let settings = settings.unwrap_or(self.default_settings.clone());
        let sys = ChatCompletionRequestSystemMessageArgs::default()
            .content(sys_msg)
            .build()?;

        let user = ChatCompletionRequestUserMessageArgs::default()
            .content(user_msg)
            .build()?;
        let req = CreateChatCompletionRequestArgs::default()
            .messages(vec![sys.into(), user.into()])
            .model(self.model.to_string())
            .temperature(settings.llm_temperature)
            .presence_penalty(settings.llm_presence_penalty)
            .max_completion_tokens(settings.llm_max_completion_tokens)
            .build()?;
        self.complete(req, prefix).await
    }
}