notapsychai/
lib.rs

1use std::fs::OpenOptions;
2use std::io::Write;
3
4use chrono::Local;
5use rustyline::config::EditMode;
6use rustyline::error::ReadlineError;
7use rustyline::hint::HistoryHinter;
8use rustyline::{CompletionType, Config, Editor, EventHandler, KeyEvent};
9use yammer::{GenerateRequest, GenerateResponse, Spinner};
10
11mod cli;
12
13use cli::{CommandHint, ShellHelper, TabEventHandler};
14
// Slugs naming the system prompt used for each check-in question; `load_system` matches on them.

/// Slug of the system prompt that grades the user's hygiene report.
const HYGIENE: &str = "hygiene";
/// Slug of the system prompt that measures how long the user has been awake.
const LAST_SLEPT: &str = "last-slept";
/// Slug of the system prompt that parses a medication/substance report.
const MEDICATION: &str = "medication";
/// Slug of the system prompt that rates the quality of the user's last sleep.
const QUALITY_OF_SLEEP: &str = "quality-of-sleep";
/// Slug of the system prompt that measures how many hours the user slept.
const SLEPT_HOW_LONG: &str = "slept-how-long";
20
21/////////////////////////////////////////////// Error //////////////////////////////////////////////
22
23#[derive(Debug)]
24pub enum Error {
25    Internal(String),
26    IO(std::io::Error),
27    Json(serde_json::Error),
28    Reqwest(reqwest::Error),
29    Yammer(yammer::Error),
30}
31
32impl std::fmt::Display for Error {
33    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
34        write!(f, "{:#?}", self)
35    }
36}
37
38impl std::error::Error for Error {}
39
40impl From<std::io::Error> for Error {
41    fn from(err: std::io::Error) -> Self {
42        Self::IO(err)
43    }
44}
45
46impl From<serde_json::Error> for Error {
47    fn from(err: serde_json::Error) -> Self {
48        Self::Json(err)
49    }
50}
51
52impl From<reqwest::Error> for Error {
53    fn from(err: reqwest::Error) -> Self {
54        Self::Reqwest(err)
55    }
56}
57
58impl From<yammer::Error> for Error {
59    fn from(err: yammer::Error) -> Self {
60        Self::Yammer(err)
61    }
62}
63
64///////////////////////////////////////////// NotAPsych ////////////////////////////////////////////
65
66pub struct NotAPsych<HELPER: rustyline::Helper, HISTORY: rustyline::history::History> {
67    editor: Editor<HELPER, HISTORY>,
68}
69
70impl<HELPER: rustyline::Helper, HISTORY: rustyline::history::History> NotAPsych<HELPER, HISTORY> {
71    pub async fn checkin(&mut self) {
72        self.sleep_checkin().await;
73        println!();
74        self.medications().await;
75        println!();
76        self.hygiene().await;
77    }
78
79    pub async fn sleep_checkin(&mut self) {
80        self.last_slept().await;
81        println!();
82        self.slept_how_long().await;
83        println!();
84        self.quality_of_sleep().await;
85    }
86
87    pub async fn last_slept(&mut self) {
88        #[derive(serde::Deserialize, yammer_derive::JsonSchema)]
89        struct LastSleptAnswer {
90            awake_hours: f64,
91            justification: String,
92        }
93        for idx in 0..3 {
94            let answer: LastSleptAnswer = match self
95                .question_and_answer(LAST_SLEPT, "When did you last wakeup? ")
96                .await
97            {
98                Ok(Some(answer)) => answer,
99                Ok(None) => {
100                    eprintln!("A blank answer is unacceptable (unless given three times).");
101                    continue;
102                }
103                Err(err) => {
104                    if idx < 2 {
105                        eprintln!("error: {err}\n\nPlease try again:\n\n");
106                    }
107                    continue;
108                }
109            };
110            let log_line = LogLine::LastSlept {
111                recorded_at: Local::now().fixed_offset().to_rfc3339(),
112                awake_hours: answer.awake_hours,
113                justification: answer.justification,
114            };
115            self.log(log_line);
116            return;
117        }
118        eprintln!("Could not interpret input, moving on...\n\n");
119    }
120
121    pub async fn slept_how_long(&mut self) {
122        #[derive(serde::Deserialize, yammer_derive::JsonSchema)]
123        struct SleptHowLongAnswer {
124            sleep_hours: f64,
125            justification: String,
126        }
127        for idx in 0..3 {
128            let answer: SleptHowLongAnswer = match self
129                .question_and_answer(
130                    SLEPT_HOW_LONG,
131                    "When you last slept, for how many hours did you sleep? ",
132                )
133                .await
134            {
135                Ok(Some(answer)) => answer,
136                Ok(None) => {
137                    eprintln!("A blank answer is unacceptable (unless given three times).");
138                    continue;
139                }
140                Err(err) => {
141                    if idx < 2 {
142                        eprintln!("error: {err}\n\nPlease try again:\n\n");
143                    }
144                    continue;
145                }
146            };
147            let log_line = LogLine::HoursSlept {
148                recorded_at: Local::now().fixed_offset().to_rfc3339(),
149                sleep_hours: answer.sleep_hours,
150                justification: answer.justification,
151            };
152            self.log(log_line);
153            return;
154        }
155        eprintln!("Could not interpret input, moving on...\n\n");
156    }
157
158    pub async fn quality_of_sleep(&mut self) {
159        #[derive(serde::Deserialize, yammer_derive::JsonSchema)]
160        struct QualityOfSleepAnswer {
161            answer: f64,
162            justification: String,
163        }
164        for idx in 0..3 {
165            let answer: QualityOfSleepAnswer = match self
166                .question_and_answer(
167                    QUALITY_OF_SLEEP,
168                    "How would you rate the quality of your last sleep schedule on a scale of 1 being the worst and 10 being the best? "
169                )
170                .await
171            {
172                Ok(Some(answer)) => answer,
173                Ok(None) => {
174                    eprintln!("A blank answer is unacceptable (unless given three times).");
175                    continue;
176                }
177                Err(err) => {
178                    if idx < 2 {
179                        eprintln!("error: {err}\n\nPlease try again:\n\n");
180                    }
181                    continue;
182                }
183            };
184            let log_line = LogLine::SleepQuality {
185                recorded_at: Local::now().fixed_offset().to_rfc3339(),
186                answer: answer.answer,
187                justification: answer.justification,
188            };
189            self.log(log_line);
190            return;
191        }
192        eprintln!("Could not interpret input, moving on...\n\n");
193    }
194
195    pub async fn medications(&mut self) {
196        #[derive(serde::Deserialize, yammer_derive::JsonSchema)]
197        struct MedicationAnswer {
198            substance: String,
199            quantity: f64,
200            units: String,
201            times_daily: f64,
202            justification: String,
203        }
204        let mut failures = 0;
205        loop {
206            let answer: MedicationAnswer = match self.question_and_answer(MEDICATION, 
207                    "List every medication you took since last checkin, and the dosage, e.g. \"30mg magic pill 3x daily.\"
208List caffeine, nicotine, and other substances as appropriate.
209Enter an empty line to continue: ").await {
210                Ok(Some(answer)) => answer,
211                Ok(None) => {
212                    return;
213                }
214                Err(err) => {
215                    if failures < 2 {
216                        eprintln!("error: {err}\n\nPlease try again:\n\n");
217                        continue;
218                    } else {
219                        eprintln!("Could not interpret input, moving on...\n\n");
220                        return;
221                    }
222                }
223            };
224            let log_line = LogLine::Medication {
225                recorded_at: Local::now().fixed_offset().to_rfc3339(),
226                substance: answer.substance,
227                dose: Dose::Daily {
228                    quantity: answer.quantity,
229                    units: answer.units,
230                    times_daily: answer.times_daily,
231                },
232                justification: answer.justification,
233            };
234            self.log(log_line);
235            println!();
236            failures = 0;
237        }
238    }
239
240    pub async fn hygiene(&mut self) {
241        let hygiene_system = self.load_system(HYGIENE);
242        let hygiene = self
243                .read_line(
244                    "Describe your hygiene in a way that translates to POOR, FAIR, GOOD, GREAT, EXCELLENT since your last report: "
245                )
246                .await;
247        let req = GenerateRequest {
248            model: self.model(),
249            prompt: hygiene,
250            format: Some(serde_json::json! {{
251                "type": "object",
252                "properties": {
253                    "answer": {
254                        "type": "string",
255                        "enum": [
256                            "POOR",
257                            "FAIR",
258                            "GOOD",
259                            "GREAT",
260                            "EXCELLENT"
261                        ]
262                    },
263                    "justification": {
264                        "type": "string"
265                    }
266                },
267                "required": [
268                  "answer",
269                  "justification"
270                ]
271            }}),
272            system: Some(hygiene_system),
273            suffix: None,
274            stream: Some(false),
275            images: None,
276            template: None,
277            raw: None,
278            keep_alive: None,
279            options: None,
280        };
281        let req = req.make_request(&self.ollama_host());
282        let spinner = Spinner::new();
283        spinner.start();
284        let resp = req
285            .send()
286            .await
287            .expect("encountered an error")
288            .error_for_status()
289            .expect("encountered an error")
290            .text()
291            .await
292            .expect("encountered an error");
293        let resp: yammer::GenerateResponse =
294            serde_json::from_str(&resp).expect("json should parse");
295        spinner.inhibit();
296        #[derive(serde::Deserialize)]
297        struct Justification {
298            answer: String,
299            justification: String,
300        }
301        let answer: Justification = match serde_json::from_str(&resp.response) {
302            Ok(json) => json,
303            Err(err) => {
304                eprintln!("Model gave bogus json: {err}");
305                return;
306            }
307        };
308        let log_line = LogLine::Hygiene {
309            recorded_at: Local::now().fixed_offset().to_rfc3339(),
310            hygiene: answer.answer,
311            justification: answer.justification,
312        };
313        self.log(log_line);
314    }
315
316    async fn question_and_answer<T: for<'a> serde::Deserialize<'a> + yammer::JsonSchema>(
317        &mut self,
318        system: &str,
319        question: &str,
320    ) -> Result<Option<T>, Error> {
321        let system = self.load_system(system);
322        let answer = self.read_line(question).await;
323        if answer.trim().is_empty() {
324            return Ok(None);
325        }
326        let req = GenerateRequest {
327            model: self.model(),
328            prompt: answer,
329            format: Some(T::json_schema()),
330            system: Some(system),
331            suffix: None,
332            stream: Some(false),
333            images: None,
334            template: None,
335            raw: None,
336            keep_alive: None,
337            options: None,
338        };
339        let req = req.make_request(&self.ollama_host());
340        let spinner = Spinner::new();
341        spinner.start();
342        let resp = req
343            .send()
344            .await?
345            .error_for_status()?
346            .json::<GenerateResponse>()
347            .await?;
348        spinner.inhibit();
349        Ok(Some(serde_json::from_str(&resp.response)?))
350    }
351
352    fn model(&self) -> String {
353        match std::env::var("NOTAPSYCH_MODEL") {
354            Ok(model) => model,
355            Err(_) => {
356                eprintln!("please set NOTAPSYCH_MODEL in your environment");
357                std::process::exit(13);
358            }
359        }
360    }
361
362    fn load_system(&self, slug: &str) -> String {
363        // TODO(rescrv):  Load from filesystem or a remote database?
364        match slug {
365            LAST_SLEPT => r#"Measure the time since the user reports they last wokeup.
366
367You are to provide your answer in hours, along with a justification in plain text.  Respond in
368JSON.
369
370To calculate this accurately, you must think step-by-step.  For example, if the user reports they
371last slept at 5:30am yesterday, and it is now 3:15pm today, first compute that there are 18.5 hours
372between 5:30am and midnight, and then 15.25 hours between midnight and now.  Add 18.5 + 15.25 to
373get 33.75 hours.  Double check your math by working in reverse, starting from now and computing backwards,
374
375Triple check your results by computing the roundup to the nearest hour at each end and then count
376the intervening hours.  For example, if the user reports they last woke at 7:25am and it is now
3775:45pm., round up 25 minutes to the hour to get 35 minutes, round 5:45pm down to the hour to get 45
378minutes (the number of minutes past the hour).  Then count that there are 9 hours between 8:00am
379and 5:00pm, for a total of 9 hours + 35 minutes + 45 minutes, or 10 hours, 20 minutes.
380
381When all three computations agree, report your results.
382
383"#
384            .to_string() + &format!("It is currently {}.", Local::now().to_rfc2822()),
385            SLEPT_HOW_LONG => r#"Measure the number of hours the user reports they slept during their most recent sleep cycle.
386
387Example:
388"8 hours" => {"sleep_hours": 8, "justification": "The user said they slept 8 hours."}
389"#.to_string(),
390            QUALITY_OF_SLEEP => r#"Interpret the user's response as a number on a scale from 0.0 to 10.0"#.to_string(),
391            MEDICATION => r#"Parse the amount of medication the user reports taking.
392
393Report -1 times-daily when there is not enough information to make a decision.
394
395Respond using JSON.
396
397Example generic:
39830mg of something once daily => {"medication": "something", "quantity": 30, "units": "mg", "times-daily": 1}
399
400Example branded, as needed (report once per day):
401Advil, 200mg, as needed => {"medication": "Advil", "quantity": 200, "units": "mg", "times-daily": 1}
402
403Example caffeine:
404Two cups of coffee per day => {"medication": "caffeine", "quantity": 150, "units": "mg", "times-daily": 2}
405
406Example alcohol (assume: 0.02 BAC per liquor shot, wine glass, or beer bottle, in a single sitting):
407Two shots of whisky and a shot of whiskey => {"medication": "alcohol", "quantity": 0.06, "units": "BAC", "times-daily": 1}
408One glass of red wine with dinner => {"medication": "alcohol", "quantity": 0.02, "units": "BAC", "times-daily": 1}
409A fifth of Jack, once a month => {"medication": "alcohol", "quantity": 0.48, "units": "BAC", "times-daily": 0.033333}
410A fifth of Jack => {"medication": "alcohol", "quantity": 0.48, "units": "BAC", "times-daily": -1}
411A thirty rack of bud with Tucker => {"medication": "alcohol", "quantity": 0.60, "units": "BAC", "times-daily": 1}
412
413Example nicotine:
414Smoke a pack a day => {"medication": "nicotine", "quantity": 10, "units": "mg", "times-daily": 20}
415Smoke two packs a day => {"medication": "nicotine", "quantity": 10, "units": "mg", "times-daily": 40}
416"#.to_string(),
417            HYGIENE => r#"Make a judgement call about the user's hygiene habits.
418
419Someone who showers and shaves every day has excellent hygiene.
420Someone who showers infrequently has poor hygiene.
421It is a spectrum of POOR, FAIR, GOOD, GREAT, EXCELLENT.
422"#.to_string(),
423            _ => panic!("logic error: {slug} not supported"),
424        }
425    }
426
427    fn ollama_host(&self) -> String {
428        match std::env::var("OLLAMA_HOST") {
429            Ok(model) => model,
430            Err(_) => {
431                eprintln!("please set OLLAMA_HOST in your environment");
432                std::process::exit(13);
433            }
434        }
435    }
436
437    fn log(&self, log_line: LogLine) {
438        let transcript = match std::env::var("NOTAPSYCH_TRANSCRIPT") {
439            Ok(transcript) => transcript,
440            Err(_) => {
441                eprintln!("please set NOTAPSYCH_TRANSCRIPT in your environment");
442                std::process::exit(13);
443            }
444        };
445        let mut log = OpenOptions::new()
446            .append(true)
447            .create(true)
448            .open(transcript)
449            .expect("could not open transcript for append");
450        log.write_all(
451            (serde_json::to_string(&log_line).expect("log line should always serialize") + "\n")
452                .as_bytes(),
453        )
454        .expect("could not append to log; it may be corrupt");
455    }
456
457    async fn read_line(&mut self, question: &str) -> String {
458        match self.editor.readline(question) {
459            Ok(line) => line,
460            Err(ReadlineError::Interrupted) | Err(ReadlineError::Eof) => {
461                std::process::exit(0);
462            }
463            Err(err) => {
464                eprintln!("could not read line: {}", err);
465                std::process::exit(13);
466            }
467        }
468    }
469}
470
471pub async fn do_it_all() {
472    let config = Config::builder()
473        .auto_add_history(true)
474        .edit_mode(EditMode::Vi)
475        .completion_type(CompletionType::List)
476        .check_cursor_position(true)
477        .max_history_size(1_000_000)
478        .expect("this should always work")
479        .history_ignore_dups(true)
480        .expect("this should always work")
481        .history_ignore_space(true)
482        .build();
483    let history = rustyline::history::FileHistory::new();
484    let mut rl = Editor::with_history(config, history).expect("this should always work");
485    let commands = vec![
486        CommandHint::new(":help", ":help"),
487        CommandHint::new(":exit", ":exit"),
488        CommandHint::new(":quit", ":quit"),
489    ];
490    let h = ShellHelper {
491        commands: commands.clone(),
492        hinter: HistoryHinter::new(),
493        hints: commands.clone(),
494    };
495    rl.set_helper(Some(h));
496    rl.bind_sequence(
497        KeyEvent::from('\t'),
498        EventHandler::Conditional(Box::new(TabEventHandler)),
499    );
500    let mut not_a_psych = NotAPsych { editor: rl };
501    not_a_psych.checkin().await;
502}
503
504#[derive(Clone, Debug, serde::Deserialize, serde::Serialize)]
505#[serde(tag = "type")]
506pub enum LogLine {
507    #[serde(rename = "last-slept")]
508    LastSlept {
509        recorded_at: String,
510        awake_hours: f64,
511        justification: String,
512    },
513    #[serde(rename = "hours-slept")]
514    HoursSlept {
515        recorded_at: String,
516        sleep_hours: f64,
517        justification: String,
518    },
519    #[serde(rename = "sleep-quality")]
520    SleepQuality {
521        recorded_at: String,
522        answer: f64,
523        justification: String,
524    },
525    #[serde(rename = "medication")]
526    Medication {
527        recorded_at: String,
528        substance: String,
529        dose: Dose,
530        justification: String,
531    },
532    #[serde(rename = "hygiene")]
533    Hygiene {
534        recorded_at: String,
535        hygiene: String,
536        justification: String,
537    },
538}
539
540#[derive(Clone, Debug, serde::Deserialize, serde::Serialize)]
541#[serde(tag = "type")]
542pub enum Dose {
543    #[serde(rename = "daily")]
544    Daily {
545        quantity: f64,
546        units: String,
547        times_daily: f64,
548    },
549}