ragit_api/
audit.rs

1use chrono::{Datelike, DateTime, Local, Utc};
2use crate::Error;
3use ragit_fs::{
4    WriteMode,
5    create_dir_all,
6    exists,
7    parent,
8    read_string,
9    write_string,
10};
11use ragit_pdl::{Message, JsonType};
12use serde_json::{Map, Value};
13use std::collections::hash_map::{Entry, HashMap};
14use std::ops::AddAssign;
15
/// Identifies where a usage record is stored: which tracker file, and under which key in it.
#[derive(Clone, Debug)]
pub struct AuditRecordAt {
    /// Path of the tracker file; it is passed to `Tracker::load_from_file` and
    /// `Tracker::save_to_file`.
    pub path: String,
    /// Key of the outer tracker map (the user whose usage is recorded).
    pub id: String,
}
21
/// Accumulated token counts and costs for one tracker entry.
///
/// It is (de)serialized as a JSON array of 4 unsigned integers, in the order
/// `[input_tokens, output_tokens, input_cost, output_cost]`.
#[derive(Clone, Copy, Debug)]
pub struct AuditRecord {
    /// Number of input tokens.
    pub input_tokens: u64,
    /// Number of output tokens.
    pub output_tokens: u64,

    // Divide this by 1 million to get dollars
    /// Cost of the input tokens, in millionths of a dollar.
    pub input_cost: u64,
    /// Cost of the output tokens, in millionths of a dollar.
    pub output_cost: u64,
}
31
32impl AddAssign<AuditRecord> for AuditRecord {
33    fn add_assign(&mut self, rhs: AuditRecord) {
34        self.input_tokens += rhs.input_tokens;
35        self.output_tokens += rhs.output_tokens;
36        self.input_cost += rhs.input_cost;
37        self.output_cost += rhs.output_cost;
38    }
39}
40
41impl From<&AuditRecord> for Value {
42    fn from(r: &AuditRecord) -> Value {
43        Value::Array(vec![
44            Value::from(r.input_tokens),
45            Value::from(r.output_tokens),
46            Value::from(r.input_cost),
47            Value::from(r.output_cost),
48        ])
49    }
50}
51
52impl TryFrom<&Value> for AuditRecord {
53    type Error = Error;
54
55    fn try_from(j: &Value) -> Result<AuditRecord, Error> {
56        let mut result = vec![];
57
58        match &j {
59            Value::Array(arr) => {
60                if arr.len() != 4 {
61                    return Err(Error::WrongSchema(format!("expected an array of length 4, but got length {}", arr.len())));
62                }
63
64                for r in arr.iter() {
65                    match r.as_u64() {
66                        Some(n) => {
67                            result.push(n);
68                        },
69                        None => {
70                            return Err(Error::JsonTypeError {
71                                expected: JsonType::U64,
72                                got: r.into(),
73                            });
74                        },
75                    }
76                }
77
78                Ok(AuditRecord {
79                    input_tokens: result[0],
80                    output_tokens: result[1],
81                    input_cost: result[2],
82                    output_cost: result[3],
83                })
84            },
85            _ => Err(Error::JsonTypeError {
86                expected: JsonType::Array,
87                got: j.into(),
88            }),
89        }
90    }
91}
92
93fn records_from_json(j: &Value) -> Result<HashMap<String, AuditRecord>, Error> {
94    match j {
95        Value::Object(obj) => {
96            let mut result = HashMap::with_capacity(obj.len());
97
98            for (key, value) in obj.iter() {
99                result.insert(key.to_string(), AuditRecord::try_from(value)?);
100            }
101
102            Ok(result)
103        },
104        Value::Array(arr) => {
105            let mut result: HashMap<String, AuditRecord> = HashMap::new();
106
107            for r in arr.iter() {
108                let AuditRecordLegacy {
109                    time,
110                    input,
111                    output,
112                    input_weight,
113                    output_weight,
114                } = AuditRecordLegacy::try_from(r)?;
115                // NOTE: RecordLegacy -> Record conversion might introduce a few hours of errors.
116                let date = match DateTime::<Utc>::from_timestamp(time as i64, 0) {
117                    Some(date) => format!("{:04}{:02}{:02}", date.year(), date.month(), date.day()),
118                    None => format!("19700101"),
119                };
120                let new_record = AuditRecord {
121                    input_tokens: input,
122                    output_tokens: output,
123                    input_cost: input * input_weight / 1000,
124                    output_cost: output * output_weight / 1000,
125                };
126
127                match result.entry(date) {
128                    Entry::Occupied(mut e) => {
129                        *e.get_mut() += new_record;
130                    },
131                    Entry::Vacant(e) => {
132                        e.insert(new_record);
133                    },
134                }
135            }
136
137            Ok(result)
138        },
139        _ => Err(Error::JsonTypeError {
140            expected: JsonType::Object,
141            got: j.into(),
142        }),
143    }
144}
145
146#[derive(Clone)]
147pub struct Tracker(pub HashMap<String, HashMap<String, AuditRecord>>);  // user_name -> usage
148
149impl Tracker {
150    pub fn new() -> Self {
151        Tracker(HashMap::new())
152    }
153
154    pub fn load_from_file(path: &str) -> Result<Self, Error> {
155        let content = read_string(path)?;
156        let j: Value = serde_json::from_str(&content)?;
157        Tracker::try_from(&j)
158    }
159
160    pub fn save_to_file(&self, path: &str) -> Result<(), Error> {
161        Ok(write_string(
162            path,
163            &serde_json::to_string_pretty(&Value::from(self))?,
164            WriteMode::Atomic,
165        )?)
166    }
167}
168
169impl TryFrom<&Value> for Tracker {
170    type Error = Error;
171
172    fn try_from(v: &Value) -> Result<Tracker, Error> {
173        match v {
174            Value::Object(obj) => {
175                let mut result = HashMap::new();
176
177                for (k, v) in obj.iter() {
178                    result.insert(k.to_string(), records_from_json(v)?);
179                }
180
181                Ok(Tracker(result))
182            },
183            _ => Err(Error::JsonTypeError {
184                expected: JsonType::Object,
185                got: v.into(),
186            }),
187        }
188    }
189}
190
191impl From<&Tracker> for Value {
192    fn from(t: &Tracker) -> Value {
193        Value::Object(t.0.iter().map(
194            |(id, records)| (
195                id.to_string(),
196                Value::Object(
197                    records.iter().map(
198                        |(date, record)| (
199                            date.to_string(),
200                            Value::from(record),
201                        )
202                    ).collect::<Map<_, _>>()
203                ),
204            )
205        ).collect())
206    }
207}
208
209pub fn dump_api_usage(
210    at: &AuditRecordAt,
211    input_tokens: u64,
212    output_tokens: u64,
213
214    // dollars per 1 billion tokens
215    input_weight: u64,
216    output_weight: u64,
217
218    // legacy option
219    _clean_up_records: bool,
220) -> Result<(), Error> {
221    let mut tracker = Tracker::load_from_file(&at.path)?;
222    let today = Local::now();
223    let today = format!("{:04}{:02}{:02}", today.year(), today.month(), today.day());
224    let new_record = AuditRecord {
225        input_tokens,
226        output_tokens,
227        input_cost: input_tokens * input_weight / 1000,
228        output_cost: output_tokens * output_weight / 1000,
229    };
230
231    match tracker.0.entry(at.id.to_string()) {
232        Entry::Occupied(mut e) => match e.get_mut().entry(today) {
233            Entry::Occupied(mut e) => {
234                *e.get_mut() += new_record;
235            },
236            Entry::Vacant(e) => {
237                e.insert(new_record);
238            },
239        },
240        Entry::Vacant(e) => {
241            e.insert([(today, new_record)].into_iter().collect());
242        },
243    }
244
245    tracker.save_to_file(&at.path)?;
246    Ok(())
247}
248
249pub fn get_user_usage_data_since(at: AuditRecordAt, since: DateTime<Local>) -> Option<HashMap<String, AuditRecord>> {
250    let since = format!("{:04}{:02}{:02}", since.year(), since.month(), since.day());
251
252    match Tracker::load_from_file(&at.path) {
253        Ok(tracker) => match tracker.0.get(&at.id) {
254            Some(records) => Some(records.iter().filter(
255                |(date, _)| date >= &&since
256            ).map(
257                |(date, record)| (date.to_string(), record.clone())
258            ).collect()),
259            None => None,
260        },
261        _ => None,
262    }
263}
264
265pub fn get_usage_data_since(path: &str, since: DateTime<Local>) -> Option<HashMap<String, AuditRecord>> {
266    let since = format!("{:04}{:02}{:02}", since.year(), since.month(), since.day());
267
268    match Tracker::load_from_file(path) {
269        Ok(tracker) => {
270            let mut result = HashMap::new();
271
272            for records in tracker.0.values() {
273                for (date, record) in records.iter() {
274                    if date >= &since {
275                        result.insert(date.to_string(), record.clone());
276                    }
277                }
278            }
279
280            Some(result)
281        },
282        _ => None,
283    }
284}
285
286/// It returns the cost in dollars (in a formatted string), without any currency unit.
287pub fn calc_usage(records: &HashMap<String, AuditRecord>) -> String {
288    // cost * 1M
289    let mut total: u64 = records.values().map(
290        |AuditRecord { input_cost, output_cost, .. }| *input_cost + *output_cost
291    ).sum();
292
293    // cost * 1K
294    total /= 1000;
295
296    format!("{:.3}", total as f64 / 1_000.0)
297}
298
299pub fn dump_pdl(
300    messages: &[Message],
301    response: &str,
302    reasoning: &Option<String>,
303    path: &str,
304    metadata: String,
305) -> Result<(), Error> {
306    let mut markdown = vec![];
307
308    for message in messages.iter() {
309        markdown.push(format!(
310            "\n\n<|{:?}|>\n\n{}",
311            message.role,
312            message.content.iter().map(|c| c.to_string()).collect::<Vec<String>>().join(""),
313        ));
314    }
315
316    markdown.push(format!(
317        "\n\n<|Assistant|>{}\n\n{response}",
318        if let Some(reasoning) = reasoning {
319            format!("\n\n<|Reasoning|>\n\n{reasoning}\n\n")
320        } else {
321            String::new()
322        },
323    ));
324    markdown.push(format!("{}# {metadata} #{}", '{', '}'));  // tera format
325
326    if let Ok(parent) = parent(path) {
327        if !exists(&parent) {
328            create_dir_all(&parent)?;
329        }
330    }
331
332    write_string(
333        path,
334        &markdown.join("\n"),
335        WriteMode::CreateOrTruncate,
336    )?;
337
338    Ok(())
339}
340
/*
 * Below is the previous implementation of `AuditRecord`.
 * I found it painfully slow, so I rewrote it from scratch (above).
 */
345
346impl From<AuditRecordLegacy> for Value {
347    fn from(r: AuditRecordLegacy) -> Value {
348        Value::Array(vec![
349            Value::from(r.time),
350            Value::from(r.input),
351            Value::from(r.output),
352            Value::from(r.input_weight),
353            Value::from(r.output_weight),
354        ])
355    }
356}
357
/// Previous on-disk representation of a usage record: one entry per API call.
///
/// It is (de)serialized as a JSON array of 5 unsigned integers, in the order
/// `[time, input, output, input_weight, output_weight]`.
#[derive(Clone, Copy, Debug)]
pub struct AuditRecordLegacy {
    /// Timestamp of the record; `records_from_json` reads it as whole seconds
    /// when converting it to a date.
    pub time: u64,
    /// Number of input tokens.
    pub input: u64,
    /// Number of output tokens.
    pub output: u64,

    // dollars per 1 billion tokens
    /// Price of the input tokens, in dollars per 1 billion tokens.
    pub input_weight: u64,
    /// Price of the output tokens, in dollars per 1 billion tokens.
    pub output_weight: u64,
}
368
369impl TryFrom<&Value> for AuditRecordLegacy {
370    type Error = Error;
371
372    fn try_from(j: &Value) -> Result<AuditRecordLegacy, Error> {
373        let mut result = vec![];
374
375        match &j {
376            Value::Array(arr) => {
377                if arr.len() != 5 {
378                    return Err(Error::WrongSchema(format!("expected an array of length 5, but got length {}", arr.len())));
379                }
380
381                for r in arr.iter() {
382                    match r.as_u64() {
383                        Some(n) => {
384                            result.push(n);
385                        },
386                        None => {
387                            return Err(Error::JsonTypeError {
388                                expected: JsonType::U64,
389                                got: r.into(),
390                            });
391                        },
392                    }
393                }
394
395                Ok(AuditRecordLegacy {
396                    time: result[0],
397                    input: result[1],
398                    output: result[2],
399                    input_weight: result[3],
400                    output_weight: result[4],
401                })
402            },
403            _ => Err(Error::JsonTypeError {
404                expected: JsonType::Array,
405                got: j.into(),
406            }),
407        }
408    }
409}