extern crate alloc;
use core::ffi::c_void;
use alloc::ffi::CString;
use alloc::string::{String, ToString};
use crate::utils::zclean;
use crate::{ErrorKind, OrtResult, Response, ThinkEvent, Write, common::stats, common::utils};
use crate::{ort_error, syscall};
/// ANSI DECTCEM "show cursor" sequence; undoes the hide issued by [`MSG_CONNECTING`].
const CURSOR_ON: &[u8] = b"\x1b[?25h";
/// Hides the cursor, prints a connection banner, then returns to column 0.
const MSG_CONNECTING: &[u8] = b"\x1b[?25lConnecting...\r";
/// Returns to column 0, erases the whole line, then starts a new line.
const MSG_CLEAR_LINE: &[u8] = b"\r\x1b[2K\n";
/// Bold "Processing..." status followed by a carriage return so later output overwrites it.
const MSG_PROCESSING: &[u8] = b"\x1b[1mProcessing...\x1b[0m\r";
/// Bold closing `</think>` tag plus newline, used when reasoning is echoed.
const MSG_THINK_TAG_END: &[u8] = b"\x1b[1m</think>\x1b[0m\n";
/// Bold "Thinking..." label used when reasoning is hidden; spinner frames follow it.
const MSG_THINKING: &[u8] = b"\x1b[1mThinking...\x1b[0m ";
/// Bold opening `<think>` tag, used when reasoning is echoed.
const MSG_THINK_TAG_START: &[u8] = b"\x1b[1m<think>\x1b[0m";
/// Spinner frames: each draws one glyph, then moves the cursor back one column
/// (`CSI 1 D`) so the next frame overwrites it.
const SPINNER: [&[u8]; 4] = [b"|\x1b[1D", b"/\x1b[1D", b"-\x1b[1D", b"\\\x1b[1D"];
/// Substring identifying an HTTP 429 (rate limited) response inside a streamed error.
const ERR_RATE_LIMITED: &str = "429 Too Many Requests";
/// Sink for the stream of [`Response`] events produced while reading a model reply.
///
/// Events arrive one at a time through [`OutputWriter::write`]; once the stream has
/// ended, [`OutputWriter::stop`] is called to finish the output.
pub trait OutputWriter {
    /// Handles a single event from the response stream.
    fn write(&mut self, data: Response) -> OrtResult<()>;

    /// Finishes output after the last event. `include_stats` requests usage stats;
    /// whether and how an implementation honours it is up to that implementation.
    fn stop(&mut self, include_stats: bool) -> OrtResult<()>;
}
/// [`OutputWriter`] for a terminal.
///
/// It prints ANSI status banners while connecting and processing. While the model
/// reasons, it shows either the reasoning text or a spinner. It then streams the answer.
pub struct ConsoleWriter<W: Write + Send> {
    /// Sink all rendered bytes are written to.
    pub writer: W,
    /// Echo reasoning text between `<think>` tags instead of showing a spinner.
    pub show_reasoning: bool,
    /// Suppress the usage-stats line printed by `stop`.
    pub is_quiet: bool,
    /// Set once the "Connecting..." banner has been printed.
    pub is_running: bool,
    /// True until the first chunk of answer content arrives (the status line is cleared then).
    pub is_first_content: bool,
    /// Counter used to select the next spinner frame.
    pub spindx: usize,
    /// Usage stats received from the stream, printed by `stop`.
    pub stats_out: Option<stats::Stats>,
}

impl<W: Write + Send> ConsoleWriter<W> {
    /// Builds a writer in its initial state: nothing printed yet and no stats received.
    pub fn new(writer: W, show_reasoning: bool, is_quiet: bool) -> ConsoleWriter<W> {
        let stats_out = None;
        let spindx = 0;
        ConsoleWriter {
            stats_out,
            spindx,
            is_first_content: true,
            is_running: false,
            is_quiet,
            show_reasoning,
            writer,
        }
    }
}
impl<W: Write + Send> OutputWriter for ConsoleWriter<W> {
    /// Restores the cursor and ends the current line. Unless disabled, it then prints
    /// usage stats.
    ///
    /// Terminal writes are best-effort: their errors are deliberately ignored.
    ///
    /// # Errors
    /// Returns `ErrorKind::MissingUsageStats` when stats were requested (and the writer is
    /// not quiet) but no `Response::Stats` event was received.
    fn stop(&mut self, include_stats: bool) -> OrtResult<()> {
        let _ = self.writer.write_all(CURSOR_ON);
        let _ = self.writer.write_all(b"\n");
        let _ = self.writer.flush();
        if !include_stats || self.is_quiet {
            return Ok(());
        }
        let Some(stats) = self.stats_out.take() else {
            return Err(ort_error(ErrorKind::MissingUsageStats, ""));
        };
        let _ = self.writer.write_all(b"\nStats: ");
        let _ = self.writer.write_all(stats.as_string().as_bytes());
        let _ = self.writer.write_char('\n');
        // The flush above ran before the stats line was written; flush again so the
        // stats are not left sitting in a buffered writer.
        let _ = self.writer.flush();
        Ok(())
    }

    /// Renders one stream event to the terminal.
    ///
    /// The first call prints the "Connecting..." banner, which also hides the cursor.
    /// Reasoning is echoed between `<think>` tags when `show_reasoning` is set; otherwise
    /// the spinner advances one frame per reasoning chunk. All byte sequences go through
    /// `write_all` so escape codes are never emitted partially.
    ///
    /// # Errors
    /// - `ErrorKind::RateLimited` if a streamed error contains "429 Too Many Requests".
    /// - `ErrorKind::ResponseStreamError` for any other streamed error. The message is
    ///   printed first.
    /// - `ErrorKind::QueueDesync` on `Response::None`.
    fn write(&mut self, data: Response) -> OrtResult<()> {
        if !self.is_running {
            let _ = self.writer.write_all(MSG_CONNECTING);
            let _ = self.writer.flush();
            self.is_running = true;
        }
        match data {
            Response::Start => {
                let _ = self.writer.write_all(MSG_PROCESSING);
                let _ = self.writer.flush();
            }
            Response::Think(think) => {
                if self.show_reasoning {
                    match think {
                        ThinkEvent::Start => {
                            let _ = self.writer.write_all(MSG_THINK_TAG_START);
                        }
                        ThinkEvent::Content(s) => {
                            let _ = self.writer.write_all(s.as_bytes());
                            let _ = self.writer.flush();
                        }
                        ThinkEvent::Stop => {
                            let _ = self.writer.write_all(MSG_THINK_TAG_END);
                        }
                    }
                } else {
                    match think {
                        ThinkEvent::Start => {
                            let _ = self.writer.write_all(MSG_THINKING);
                            let _ = self.writer.flush();
                        }
                        ThinkEvent::Content(_) => {
                            let _ = self.writer.write_all(SPINNER[self.spindx % SPINNER.len()]);
                            let _ = self.writer.flush();
                            // Only the value modulo SPINNER.len() matters, so wrapping keeps the
                            // frame sequence intact and avoids an overflow panic in debug builds.
                            self.spindx = self.spindx.wrapping_add(1);
                        }
                        ThinkEvent::Stop => {}
                    }
                }
            }
            Response::Content(content) => {
                if self.is_first_content {
                    // Wipe the "Processing..."/spinner status line before the answer starts.
                    let _ = self.writer.write_all(MSG_CLEAR_LINE);
                    self.is_first_content = false;
                }
                let _ = self.writer.write_all(content.as_bytes());
                let _ = self.writer.flush();
            }
            Response::Stats(stats) => {
                self.stats_out = Some(stats);
            }
            Response::Error(err_string) => {
                let _ = self.writer.write_all(CURSOR_ON);
                let _ = self.writer.flush();
                if err_string.contains(ERR_RATE_LIMITED) {
                    return Err(ort_error(ErrorKind::RateLimited, ""));
                }
                utils::print_string(c"\nERROR: ", &err_string);
                return Err(ort_error(
                    ErrorKind::ResponseStreamError,
                    "OpenRouter returned an error",
                ));
            }
            Response::None => {
                // Report a queue desync as an error, as FileWriter and CollectedWriter do,
                // rather than panicking. Restore the cursor first, like the Error arm.
                let _ = self.writer.write_all(CURSOR_ON);
                let _ = self.writer.flush();
                return Err(ort_error(
                    ErrorKind::QueueDesync,
                    "Response::None means we read the wrong Queue position",
                ));
            }
        }
        Ok(())
    }
}
/// Plain-text [`OutputWriter`]: it writes no escape sequences, banners or spinner.
/// Reasoning is optionally wrapped in `<think>` tags.
pub struct FileWriter<W: Write + Send> {
    /// Destination for the response text.
    pub writer: W,
    /// Include reasoning text (inside `<think>` tags) in the output.
    pub show_reasoning: bool,
    /// Suppress the usage-stats line written by `stop`.
    pub is_quiet: bool,
    /// Usage stats received from the stream, written by `stop`.
    pub stats_out: Option<stats::Stats>,
}

impl<W: Write + Send> FileWriter<W> {
    /// Builds a writer that has not received any stats yet.
    pub fn new(writer: W, show_reasoning: bool, is_quiet: bool) -> Self {
        Self {
            stats_out: None,
            is_quiet,
            show_reasoning,
            writer,
        }
    }
}
impl<W: Write + Send> OutputWriter for FileWriter<W> {
    /// Writes one stream event as plain text.
    ///
    /// Writes to `writer` are best-effort: errors are ignored. `write_all` is used so
    /// nothing is silently truncated by a partial write. The final flush happens in `stop`.
    ///
    /// # Errors
    /// - `ErrorKind::RateLimited` if a streamed error contains "429 Too Many Requests".
    /// - `ErrorKind::ResponseStreamError` for any other streamed error. The message is
    ///   written to fd 2 first.
    /// - `ErrorKind::QueueDesync` on `Response::None`.
    fn write(&mut self, data: Response) -> OrtResult<()> {
        match data {
            Response::Start => {}
            Response::Think(think) => {
                if self.show_reasoning {
                    match think {
                        ThinkEvent::Start => {
                            let _ = self.writer.write_all(b"<think>");
                        }
                        ThinkEvent::Content(s) => {
                            let _ = self.writer.write_all(s.as_bytes());
                        }
                        ThinkEvent::Stop => {
                            let _ = self.writer.write_all(b"</think>\n\n");
                        }
                    }
                }
            }
            Response::Content(content) => {
                let _ = self.writer.write_all(content.as_bytes());
            }
            Response::Stats(stats) => {
                self.stats_out = Some(stats);
            }
            Response::Error(mut err_string) => {
                if err_string.contains(ERR_RATE_LIMITED) {
                    return Err(ort_error(ErrorKind::RateLimited, ""));
                }
                let mut msg = "\nERROR: ".to_string() + zclean(&mut err_string);
                // The message comes from the server. `CString::new` fails on interior NUL
                // bytes, so drop any that remain; reporting an error must never panic.
                msg.retain(|c| c != '\0');
                let c_s = CString::new(msg).expect("interior NUL bytes were removed above");
                syscall::write(2, c_s.as_ptr().cast(), c_s.count_bytes());
                return Err(ort_error(
                    ErrorKind::ResponseStreamError,
                    "OpenRouter returned an error",
                ));
            }
            Response::None => {
                return Err(ort_error(
                    ErrorKind::QueueDesync,
                    "Response::None means we read the wrong Queue position",
                ));
            }
        }
        Ok(())
    }

    /// Ends the output with a newline, optionally appends usage stats, then flushes.
    ///
    /// # Errors
    /// Returns `ErrorKind::MissingUsageStats` when stats were requested (and the writer is
    /// not quiet) but no `Response::Stats` event was received. The writer is still
    /// flushed in that case.
    fn stop(&mut self, include_stats: bool) -> OrtResult<()> {
        let _ = self.writer.write_all(b"\n");
        let result = if !include_stats || self.is_quiet {
            Ok(())
        } else if let Some(stats) = self.stats_out.take() {
            let _ = self.writer.write_all(b"\nStats: ");
            let _ = self.writer.write_all(stats.as_string().as_bytes());
            let _ = self.writer.write_char('\n');
            Ok(())
        } else {
            Err(ort_error(ErrorKind::MissingUsageStats, ""))
        };
        // Flush on every path so a buffered `W` does not lose the tail of the output.
        let _ = self.writer.flush();
        result
    }
}
/// [`OutputWriter`] that buffers the answer in memory instead of printing it.
///
/// Only answer content and usage stats are kept. `stop` renders them into `output`.
pub struct CollectedWriter {
    /// Answer text accumulated so far.
    contents: String,
    /// Usage stats received from the stream, if any.
    got_stats: Option<stats::Stats>,
    /// Final rendered result; `None` until `stop` sets it.
    pub output: Option<String>,
}

impl CollectedWriter {
    /// Creates an empty collector, pre-allocating 4096 bytes for content.
    pub fn new() -> Self {
        Self {
            got_stats: None,
            contents: String::with_capacity(4096),
            output: None,
        }
    }
}

/// Same as [`CollectedWriter::new`]. It gives the public no-argument constructor its
/// conventional `Default` counterpart (clippy `new_without_default`).
impl Default for CollectedWriter {
    fn default() -> Self {
        Self::new()
    }
}
impl OutputWriter for CollectedWriter {
    /// Accumulates answer content and records usage stats. Start and reasoning events
    /// are ignored.
    ///
    /// # Errors
    /// - `ErrorKind::ResponseStreamError` if the stream reports an error.
    /// - `ErrorKind::QueueDesync` on `Response::None`.
    fn write(&mut self, data: Response) -> OrtResult<()> {
        match data {
            Response::Start => {}
            Response::Think(_) => {}
            Response::Content(content) => {
                self.contents.push_str(&content);
            }
            Response::Stats(stats) => {
                self.got_stats = Some(stats);
            }
            Response::Error(_err) => {
                return Err(ort_error(
                    ErrorKind::ResponseStreamError,
                    "CollectedWriter response error",
                ));
            }
            Response::None => {
                return Err(ort_error(
                    ErrorKind::QueueDesync,
                    "Response::None means we read the wrong Queue position",
                ));
            }
        }
        Ok(())
    }

    /// Renders `output` as `--- <stats> ---\n<content>`.
    ///
    /// `_include_stats` is ignored: the stats header is always produced.
    ///
    /// # Errors
    /// Returns `ErrorKind::MissingUsageStats` if no `Response::Stats` event was received.
    /// This replaces a panic and matches the other writers' `stop`.
    fn stop(&mut self, _include_stats: bool) -> OrtResult<()> {
        const HEADER_OPEN: &str = "--- ";
        const HEADER_CLOSE: &str = " ---\n";
        let Some(stats) = self.got_stats.take() else {
            return Err(ort_error(ErrorKind::MissingUsageStats, ""));
        };
        let stat_string = stats.as_string();
        let mut out = String::with_capacity(
            HEADER_OPEN.len() + stat_string.len() + HEADER_CLOSE.len() + self.contents.len(),
        );
        out.push_str(HEADER_OPEN);
        out.push_str(&stat_string);
        out.push_str(HEADER_CLOSE);
        out.push_str(&self.contents);
        self.output = Some(out);
        Ok(())
    }
}
/// [`Write`] sink that hands bytes straight to file descriptor 1 through a raw syscall.
///
/// It keeps no buffer of its own, so `flush` has nothing to do.
pub struct StdoutWriter {}

impl Write for StdoutWriter {
    /// Issues a single `write` syscall on fd 1 and reports how many bytes it accepted.
    /// That count may be less than `buf.len()`.
    ///
    /// # Errors
    /// `ErrorKind::StdoutWriteFailed` when the syscall returns a negative value.
    fn write(&mut self, buf: &[u8]) -> OrtResult<usize> {
        let written = syscall::write(1, buf.as_ptr().cast::<c_void>(), buf.len());
        if written < 0 {
            return Err(ort_error(ErrorKind::StdoutWriteFailed, ""));
        }
        Ok(written as usize)
    }

    /// No-op: nothing is buffered here.
    fn flush(&mut self) -> OrtResult<()> {
        Ok(())
    }
}