Skip to main content

ort_openrouter_cli/output/
writer.rs

//! ort: Open Router CLI
//! https://github.com/grahamking/ort
//!
//! MIT License
//! Copyright (c) 2025 Graham King

7extern crate alloc;
8use core::ffi::c_void;
9
10use alloc::ffi::CString;
11use alloc::string::{String, ToString};
12
13use crate::utils::zclean;
14use crate::{ErrorKind, OrtResult, Response, ThinkEvent, Write, common::stats, common::utils};
15use crate::{ort_error, syscall};
16
/// ANSI escape that makes the terminal cursor visible again (`ESC[?25h`).
const CURSOR_ON: &[u8] = b"\x1b[?25h";

/// Hides the cursor (`ESC[?25l`) and prints a "Connecting..." status. The
/// trailing carriage return moves back to column 0 so the next status message
/// overwrites this one.
const MSG_CONNECTING: &[u8] = concat!("\x1b[?25l", "Connecting...", "\r").as_bytes();

/// Returns to column 0, erases the whole line (`ESC[2K`), then starts a new line.
const MSG_CLEAR_LINE: &[u8] = b"\r\x1b[2K\n";
24
// Each status message below is wrapped in ANSI bold-on (`ESC[1m`) and reset
// (`ESC[0m`). `concat!` only accepts literals, not `const` items, so the escape
// sequences are spelled inline and joined at compile time.

/// Status shown once the request has started; `\r` lets the next status overwrite it.
const MSG_PROCESSING: &[u8] = concat!("\x1b[1m", "Processing...", "\x1b[0m", "\r").as_bytes();
/// Closes a visible reasoning block and ends the line.
const MSG_THINK_TAG_END: &[u8] = concat!("\x1b[1m", "</think>", "\x1b[0m", "\n").as_bytes();
/// Status shown while reasoning is hidden; followed by two spaces, after which the
/// spinner frames are drawn.
const MSG_THINKING: &[u8] = concat!("\x1b[1m", "Thinking...", "\x1b[0m", "  ").as_bytes();
/// Opens a visible reasoning block.
const MSG_THINK_TAG_START: &[u8] = concat!("\x1b[1m", "<think>", "\x1b[0m").as_bytes();
33
/// Spinner frames `|`, `/`, `-`, `\`. Every frame ends with `ESC[1D` (cursor
/// back one column), so the next frame is drawn on top of the previous one and
/// the character appears to spin.
const SPINNER: [&[u8]; 4] = [b"|\x1b[1D", b"/\x1b[1D", b"-\x1b[1D", b"\\\x1b[1D"];

/// Substring that identifies an HTTP 429 rate-limit error in an error message.
const ERR_RATE_LIMITED: &str = "429 Too Many Requests";
47
48pub trait OutputWriter {
49    fn write(&mut self, data: Response) -> OrtResult<()>;
50    fn stop(&mut self, include_stats: bool) -> OrtResult<()>;
51}
52
53pub struct ConsoleWriter<W: Write + Send> {
54    pub writer: W, // Must handle ANSI control chars
55    pub show_reasoning: bool,
56    pub is_quiet: bool,
57    pub is_running: bool,
58    pub is_first_content: bool,
59    pub spindx: usize,
60    pub stats_out: Option<stats::Stats>,
61}
62
63impl<W: Write + Send> ConsoleWriter<W> {
64    pub fn new(writer: W, show_reasoning: bool, is_quiet: bool) -> ConsoleWriter<W> {
65        ConsoleWriter {
66            writer,
67            show_reasoning,
68            is_quiet,
69            is_running: false,
70            is_first_content: true,
71            spindx: 0,
72            stats_out: None,
73        }
74    }
75}
76
77impl<W: Write + Send> OutputWriter for ConsoleWriter<W> {
78    fn stop(&mut self, include_stats: bool) -> OrtResult<()> {
79        let _ = self.writer.write(CURSOR_ON);
80        let _ = self.writer.write(b"\n");
81        let _ = self.writer.flush();
82        if !include_stats || self.is_quiet {
83            return Ok(());
84        }
85
86        let Some(stats) = self.stats_out.take() else {
87            return Err(ort_error(ErrorKind::MissingUsageStats, ""));
88        };
89        let _ = self.writer.write("\nStats: ".as_bytes());
90        let _ = self.writer.write(stats.as_string().as_bytes());
91        let _ = self.writer.write_char('\n');
92
93        Ok(())
94    }
95
96    fn write(&mut self, data: Response) -> OrtResult<()> {
97        if !self.is_running {
98            let _ = self.writer.write(MSG_CONNECTING);
99            let _ = self.writer.flush();
100            self.is_running = true;
101        }
102
103        match data {
104            Response::Start => {
105                let _ = self.writer.write(MSG_PROCESSING);
106                let _ = self.writer.flush();
107            }
108            Response::Think(think) => {
109                if self.show_reasoning {
110                    match think {
111                        ThinkEvent::Start => {
112                            let _ = self.writer.write(MSG_THINK_TAG_START);
113                        }
114                        ThinkEvent::Content(s) => {
115                            let _ = self.writer.write_all(s.as_bytes());
116                            let _ = self.writer.flush();
117                        }
118                        ThinkEvent::Stop => {
119                            let _ = self.writer.write(MSG_THINK_TAG_END);
120                        }
121                    }
122                } else {
123                    match think {
124                        ThinkEvent::Start => {
125                            let _ = self.writer.write(MSG_THINKING);
126                            let _ = self.writer.flush();
127                        }
128                        ThinkEvent::Content(_) => {
129                            let _ = self.writer.write(SPINNER[self.spindx % SPINNER.len()]);
130                            let _ = self.writer.flush();
131                            self.spindx += 1;
132                        }
133                        ThinkEvent::Stop => {}
134                    }
135                }
136            }
137            Response::Content(content) => {
138                if self.is_first_content {
139                    // Erase the Processing or Thinking line
140                    let _ = self.writer.write(MSG_CLEAR_LINE);
141                    self.is_first_content = false;
142                }
143                let _ = self.writer.write_all(content.as_bytes());
144                let _ = self.writer.flush();
145            }
146            Response::Stats(stats) => {
147                self.stats_out = Some(stats);
148            }
149            Response::Error(err_string) => {
150                let _ = self.writer.write(CURSOR_ON);
151                let _ = self.writer.flush();
152                if err_string.contains(ERR_RATE_LIMITED) {
153                    return Err(ort_error(ErrorKind::RateLimited, ""));
154                }
155                utils::print_string(c"\nERROR: ", &err_string);
156                return Err(ort_error(
157                    ErrorKind::ResponseStreamError,
158                    "OpenRouter returned an error",
159                ));
160            }
161            Response::None => {
162                panic!("Response::None means we read the wrong Queue position");
163            }
164        }
165
166        Ok(())
167    }
168}
169
170pub struct FileWriter<W: Write + Send> {
171    pub writer: W,
172    pub show_reasoning: bool,
173    pub is_quiet: bool,
174    pub stats_out: Option<stats::Stats>,
175}
176
177impl<W: Write + Send> FileWriter<W> {
178    pub fn new(writer: W, show_reasoning: bool, is_quiet: bool) -> FileWriter<W> {
179        FileWriter {
180            writer,
181            show_reasoning,
182            is_quiet,
183            stats_out: None,
184        }
185    }
186}
187
188impl<W: Write + Send> OutputWriter for FileWriter<W> {
189    fn write(&mut self, data: Response) -> OrtResult<()> {
190        match data {
191            Response::Start => {}
192            Response::Think(think) => {
193                if self.show_reasoning {
194                    match think {
195                        ThinkEvent::Start => {
196                            let _ = self.writer.write("<think>".as_bytes());
197                        }
198                        ThinkEvent::Content(s) => {
199                            let _ = self.writer.write_all(s.as_bytes());
200                        }
201                        ThinkEvent::Stop => {
202                            let _ = self.writer.write("</think>\n\n".as_bytes());
203                        }
204                    }
205                }
206            }
207            Response::Content(content) => {
208                let _ = self.writer.write_all(content.as_bytes());
209            }
210            Response::Stats(stats) => {
211                self.stats_out = Some(stats);
212            }
213            Response::Error(mut err_string) => {
214                if err_string.contains(ERR_RATE_LIMITED) {
215                    return Err(ort_error(ErrorKind::RateLimited, ""));
216                }
217                let c_s = CString::new("\nERROR: ".to_string() + zclean(&mut err_string)).unwrap();
218                syscall::write(2, c_s.as_ptr().cast(), c_s.count_bytes());
219                return Err(ort_error(
220                    ErrorKind::ResponseStreamError,
221                    "OpenRouter returned an error",
222                ));
223            }
224            Response::None => {
225                return Err(ort_error(
226                    ErrorKind::QueueDesync,
227                    "Response::None means we read the wrong Queue position",
228                ));
229            }
230        }
231        Ok(())
232    }
233
234    fn stop(&mut self, include_stats: bool) -> OrtResult<()> {
235        let _ = self.writer.write(b"\n");
236        if !include_stats || self.is_quiet {
237            return Ok(());
238        }
239
240        let Some(stats) = self.stats_out.take() else {
241            return Err(ort_error(ErrorKind::MissingUsageStats, ""));
242        };
243        let _ = self.writer.write("\nStats: ".as_bytes());
244        let _ = self.writer.write(stats.as_string().as_bytes());
245        let _ = self.writer.write_char('\n');
246        Ok(())
247    }
248}
249
250pub struct CollectedWriter {
251    contents: String,
252    got_stats: Option<stats::Stats>,
253    pub output: Option<String>,
254}
255
256impl CollectedWriter {
257    pub fn new() -> Self {
258        Self {
259            got_stats: None,
260            contents: String::with_capacity(4096),
261            output: None,
262        }
263    }
264}
265
266impl OutputWriter for CollectedWriter {
267    fn write(&mut self, data: Response) -> OrtResult<()> {
268        match data {
269            Response::Start => {}
270            Response::Think(_) => {}
271            Response::Content(content) => {
272                self.contents.push_str(&content);
273            }
274            Response::Stats(stats) => {
275                self.got_stats = Some(stats);
276            }
277            Response::Error(_err) => {
278                // Original message: CollectedWriter + err detail
279                return Err(ort_error(
280                    ErrorKind::ResponseStreamError,
281                    "CollectedWriter response error",
282                ));
283            }
284            Response::None => {
285                return Err(ort_error(
286                    ErrorKind::QueueDesync,
287                    "Response::None means we read the wrong Queue position",
288                ));
289            }
290        }
291        Ok(())
292    }
293
294    fn stop(&mut self, _include_stats: bool) -> OrtResult<()> {
295        let stat_string = self.got_stats.take().unwrap().as_string();
296        let mut out = String::with_capacity(stat_string.len() + self.contents.len() + 9);
297        out.push_str("--- ");
298        out.push_str(&stat_string);
299        out.push_str(" ---\n");
300        out.push_str(&self.contents);
301
302        self.output = Some(out);
303        Ok(())
304    }
305}
306
307pub struct StdoutWriter {}
308
309impl Write for StdoutWriter {
310    fn write(&mut self, buf: &[u8]) -> OrtResult<usize> {
311        let bytes_written = syscall::write(1, buf.as_ptr() as *const c_void, buf.len());
312        if bytes_written >= 0 {
313            Ok(bytes_written as usize)
314        } else {
315            Err(ort_error(ErrorKind::StdoutWriteFailed, ""))
316        }
317    }
318
319    fn flush(&mut self) -> OrtResult<()> {
320        Ok(())
321    }
322}