Skip to main content

ort_openrouter_cli/output/
writer.rs

1//! ort: Open Router CLI
2//! https://github.com/grahamking/ort
3//!
4//! MIT License
5//! Copyright (c) 2025 Graham King
6
7extern crate alloc;
8use core::ffi::c_void;
9
10use alloc::ffi::CString;
11use alloc::string::{String, ToString};
12use alloc::vec::Vec;
13
14use crate::utils::zclean;
15use crate::{
16    Context, ErrorKind, LastData, Message, OrtResult, PromptOpts, Response, ThinkEvent, Write,
17    common::config, common::file, common::queue, common::stats, common::utils,
18};
19use crate::{libc, ort_error};
20
/// ANSI sequence that makes the terminal cursor visible again.
const CURSOR_ON: &[u8] = b"\x1b[?25h";

/// Hides the cursor (`\x1b[?25l`), prints the initial status text and returns
/// to the start of the line so the next status message overwrites it.
const MSG_CONNECTING: &[u8] = b"\x1b[?25lConnecting...\r";

/// Carriage return, the ANSI "clear line" sequence, then a newline.
const MSG_CLEAR_LINE: &[u8] = b"\r\x1b[2K\n";

// The status messages below are wrapped in the ANSI bold-start (`\x1b[1m`) and
// bold-end (`\x1b[0m`) sequences. The escapes are spelled out in every literal
// because named constants cannot be concatenated at build time.

/// Bold "Processing..." followed by a carriage return.
const MSG_PROCESSING: &[u8] = b"\x1b[1mProcessing...\x1b[0m\r";
/// Bold closing think tag followed by a newline.
const MSG_THINK_TAG_END: &[u8] = b"\x1b[1m</think>\x1b[0m\n";
/// Bold "Thinking..." followed by two spaces; the spinner is drawn after it.
const MSG_THINKING: &[u8] = b"\x1b[1mThinking...\x1b[0m  ";
/// Bold opening think tag.
const MSG_THINK_TAG_START: &[u8] = b"\x1b[1m<think>\x1b[0m";

/// Spinner frames: `|`, `/`, `-`, `\`. Printed in sequence they look like a
/// rotating bar. Each frame ends with the ANSI "move cursor back one column"
/// escape (`\x1b[1D`) so the next frame overwrites it.
const SPINNER: [&[u8]; 4] = [
    b"|\x1b[1D",
    b"/\x1b[1D",
    b"-\x1b[1D",
    b"\\\x1b[1D",
];

/// Substring of an error message that identifies a rate-limit response.
const ERR_RATE_LIMITED: &str = "429 Too Many Requests";
51
52pub struct ConsoleWriter<W: Write + Send> {
53    pub writer: W, // Must handle ANSI control chars
54    pub show_reasoning: bool,
55}
56
57impl<W: Write + Send> ConsoleWriter<W> {
58    pub fn into_inner(self) -> W {
59        self.writer
60    }
61    pub fn run<const N: usize>(
62        &mut self,
63        mut rx: queue::Consumer<Response, N>,
64    ) -> OrtResult<stats::Stats> {
65        let _ = self.writer.write(MSG_CONNECTING);
66        let _ = self.writer.flush();
67
68        let mut is_first_content = true;
69        let mut spindx = 0;
70        let mut stats_out = None;
71        while let Some(data) = rx.get_next() {
72            match data {
73                Response::Start => {
74                    let _ = self.writer.write(MSG_PROCESSING);
75                    let _ = self.writer.flush();
76                }
77                Response::Think(think) => {
78                    if self.show_reasoning {
79                        match think {
80                            ThinkEvent::Start => {
81                                let _ = self.writer.write(MSG_THINK_TAG_START);
82                            }
83                            ThinkEvent::Content(s) => {
84                                let _ = self.writer.write_all(s.as_bytes());
85                                let _ = self.writer.flush();
86                            }
87                            ThinkEvent::Stop => {
88                                let _ = self.writer.write(MSG_THINK_TAG_END);
89                            }
90                        }
91                    } else {
92                        match think {
93                            ThinkEvent::Start => {
94                                let _ = self.writer.write(MSG_THINKING);
95                                let _ = self.writer.flush();
96                            }
97                            ThinkEvent::Content(_) => {
98                                let _ = self.writer.write(SPINNER[spindx % SPINNER.len()]);
99                                let _ = self.writer.flush();
100                                spindx += 1;
101                            }
102                            ThinkEvent::Stop => {}
103                        }
104                    }
105                }
106                Response::Content(content) => {
107                    if is_first_content {
108                        // Erase the Processing or Thinking line
109                        let _ = self.writer.write(MSG_CLEAR_LINE);
110                        is_first_content = false;
111                    }
112                    let _ = self.writer.write_all(content.as_bytes());
113                    let _ = self.writer.flush();
114                }
115                Response::Stats(stats) => {
116                    stats_out = Some(stats);
117                }
118                Response::Error(err_string) => {
119                    let _ = self.writer.write(CURSOR_ON);
120                    let _ = self.writer.flush();
121                    if err_string.contains(ERR_RATE_LIMITED) {
122                        return Err(ort_error(ErrorKind::RateLimited, ""));
123                    }
124                    utils::print_string(c"\nERROR: ", &err_string);
125                    return Err(ort_error(
126                        ErrorKind::ResponseStreamError,
127                        "OpenRouter returned an error",
128                    ));
129                }
130                Response::None => {
131                    panic!("Response::None means we read the wrong Queue position");
132                }
133            }
134        }
135
136        let _ = self.writer.write(CURSOR_ON);
137        let _ = self.writer.flush();
138
139        let Some(stats) = stats_out else {
140            return Err(ort_error(ErrorKind::MissingUsageStats, ""));
141        };
142        Ok(stats)
143    }
144}
145
146pub struct FileWriter<W: Write + Send> {
147    pub writer: W,
148    pub show_reasoning: bool,
149}
150
151impl<W: Write + Send> FileWriter<W> {
152    pub fn into_inner(self) -> W {
153        self.writer
154    }
155    pub fn run<const N: usize>(
156        &mut self,
157        mut rx: queue::Consumer<Response, N>,
158    ) -> OrtResult<stats::Stats> {
159        let mut stats_out = None;
160        while let Some(data) = rx.get_next() {
161            match data {
162                Response::Start => {}
163                Response::Think(think) => {
164                    if self.show_reasoning {
165                        match think {
166                            ThinkEvent::Start => {
167                                let _ = self.writer.write("<think>".as_bytes());
168                            }
169                            ThinkEvent::Content(s) => {
170                                let _ = self.writer.write_all(s.as_bytes());
171                            }
172                            ThinkEvent::Stop => {
173                                let _ = self.writer.write("</think>\n\n".as_bytes());
174                            }
175                        }
176                    }
177                }
178                Response::Content(content) => {
179                    let _ = self.writer.write_all(content.as_bytes());
180                }
181                Response::Stats(stats) => {
182                    stats_out = Some(stats);
183                }
184                Response::Error(mut err_string) => {
185                    if err_string.contains(ERR_RATE_LIMITED) {
186                        return Err(ort_error(ErrorKind::RateLimited, ""));
187                    }
188                    let c_s =
189                        CString::new("\nERROR: ".to_string() + zclean(&mut err_string)).unwrap();
190                    unsafe {
191                        libc::write(2, c_s.as_ptr().cast(), c_s.count_bytes());
192                    }
193                    return Err(ort_error(
194                        ErrorKind::ResponseStreamError,
195                        "OpenRouter returned an error",
196                    ));
197                }
198                Response::None => {
199                    return Err(ort_error(
200                        ErrorKind::QueueDesync,
201                        "Response::None means we read the wrong Queue position",
202                    ));
203                }
204            }
205        }
206
207        let Some(stats) = stats_out else {
208            return Err(ort_error(ErrorKind::MissingUsageStats, ""));
209        };
210        Ok(stats)
211    }
212}
213
214pub struct CollectedWriter {}
215
216impl CollectedWriter {
217    pub fn run<const N: usize>(
218        &mut self,
219        mut rx: queue::Consumer<Response, N>,
220    ) -> OrtResult<String> {
221        let mut got_stats = None;
222        let mut contents = String::with_capacity(4096);
223        while let Some(data) = rx.get_next() {
224            match data {
225                Response::Start => {}
226                Response::Think(_) => {}
227                Response::Content(content) => {
228                    contents.push_str(&content);
229                }
230                Response::Stats(stats) => {
231                    got_stats = Some(stats);
232                }
233                Response::Error(_err) => {
234                    // Original message: CollectedWriter + err detail
235                    return Err(ort_error(
236                        ErrorKind::ResponseStreamError,
237                        "CollectedWriter response error",
238                    ));
239                }
240                Response::None => {
241                    return Err(ort_error(
242                        ErrorKind::QueueDesync,
243                        "Response::None means we read the wrong Queue position",
244                    ));
245                }
246            }
247        }
248
249        let stat_string = got_stats.unwrap().as_string();
250        let mut out = String::with_capacity(stat_string.len() + contents.len() + 9);
251        out.push_str("--- ");
252        out.push_str(&stat_string);
253        out.push_str(" ---\n");
254        out.push_str(&contents);
255        Ok(out)
256    }
257}
258
259pub struct LastWriter {
260    w: file::File,
261    data: LastData,
262}
263
264impl LastWriter {
265    pub fn new(opts: PromptOpts, messages: Vec<Message>) -> OrtResult<Self> {
266        let mut last_path = [0u8; 128];
267        let idx = config::cache_dir(&mut last_path)?;
268        last_path[idx] = b'/';
269        let last_filename = utils::last_filename();
270        let start = idx + 1;
271        let end = start + last_filename.len();
272        last_path[start..end].copy_from_slice(last_filename.as_bytes());
273        // end + 1 to add a null byte on the end
274        let last_file =
275            unsafe { file::File::create(&last_path[..end + 1]).context("create last file")? };
276        let data = LastData { opts, messages };
277        Ok(LastWriter { data, w: last_file })
278    }
279
280    pub fn run<const N: usize>(
281        &mut self,
282        mut rx: queue::Consumer<Response, N>,
283    ) -> OrtResult<stats::Stats> {
284        // This will contain the entire model response. Start with a size that includes most
285        // answers, but allow realloc. Maybe we should stream to disk?
286        let mut contents = String::with_capacity(4096);
287        while let Some(data) = rx.get_next() {
288            match data {
289                Response::Start => {}
290                Response::Think(_) => {}
291                Response::Content(content) => {
292                    contents.push_str(&content);
293                }
294                Response::Stats(stats) => {
295                    self.data.opts.provider = Some(utils::slug(stats.provider()));
296                }
297                Response::Error(_err) => {
298                    return Err(ort_error(
299                        ErrorKind::LastWriterError,
300                        "LastWriter run error",
301                    ));
302                }
303                Response::None => {
304                    return Err(ort_error(
305                        ErrorKind::QueueDesync,
306                        "Response::None means we read the wrong Queue position",
307                    ));
308                }
309            }
310        }
311
312        let message = Message::assistant(contents);
313        self.data.messages.push(message);
314
315        self.data.to_json_writer(&mut self.w)?;
316        let _ = (&mut self.w).flush();
317
318        Ok(stats::Stats::default()) // Stats is not used
319    }
320}
321
322pub struct StdoutWriter {}
323
324impl Write for StdoutWriter {
325    fn write(&mut self, buf: &[u8]) -> OrtResult<usize> {
326        let bytes_written = unsafe { libc::write(1, buf.as_ptr() as *const c_void, buf.len()) };
327        if bytes_written >= 0 {
328            Ok(bytes_written as usize)
329        } else {
330            Err(ort_error(ErrorKind::StdoutWriteFailed, ""))
331        }
332    }
333
334    fn flush(&mut self) -> OrtResult<()> {
335        Ok(())
336    }
337}