use std::error::Error;
use std::sync::atomic::AtomicBool;
use std::sync::Arc;
use clap::{Subcommand, ValueEnum};
use lino_arguments::Parser;
use formal_ai::{
create_chat_completion, create_response, knowledge_links_notation, run_telegram_polling,
run_telegram_webhook_server, ChatCompletionRequest, ChatMessage, FormalAiEngine,
MessageContent, ResponsesRequest, TelegramPollingConfig, DEFAULT_MODEL,
};
#[derive(Parser, Debug)]
#[command(name = "formal-ai", about = "Formal symbolic AI proof of concept")]
struct Args {
#[command(subcommand)]
command: Option<Command>,
}
#[derive(Debug, Subcommand)]
enum Command {
Chat {
#[arg(long, env = "FORMAL_AI_PROMPT")]
prompt: String,
#[arg(long, value_enum, default_value_t = OutputFormat::Text)]
format: OutputFormat,
},
Dataset,
Serve {
#[arg(long, env = "FORMAL_AI_HOST", default_value = "127.0.0.1")]
host: String,
#[arg(long, env = "FORMAL_AI_PORT", default_value_t = 8080)]
port: u16,
},
Telegram {
#[arg(
long,
value_enum,
env = "FORMAL_AI_TELEGRAM_MODE",
default_value_t = TelegramMode::Polling
)]
mode: TelegramMode,
#[arg(long, env = "TELEGRAM_BOT_TOKEN")]
token: Option<String>,
#[arg(
long,
env = "FORMAL_AI_TELEGRAM_API_BASE",
default_value = "https://api.telegram.org"
)]
api_base: String,
#[arg(long, env = "FORMAL_AI_TELEGRAM_TIMEOUT", default_value_t = 30)]
timeout: u32,
#[arg(long, env = "FORMAL_AI_TELEGRAM_LIMIT", default_value_t = 100)]
limit: u32,
#[arg(long, env = "FORMAL_AI_TELEGRAM_ALLOWED_UPDATES", default_value = "")]
allowed_updates: String,
#[arg(long, env = "FORMAL_AI_HOST", default_value = "127.0.0.1")]
host: String,
#[arg(long, env = "FORMAL_AI_PORT", default_value_t = 8080)]
port: u16,
},
}
#[derive(Debug, Clone, Copy, ValueEnum)]
enum OutputFormat {
Text,
Chat,
Responses,
}
#[derive(Debug, Clone, Copy, ValueEnum, PartialEq, Eq)]
enum TelegramMode {
Polling,
Webhook,
}
impl std::fmt::Display for OutputFormat {
fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::Text => formatter.write_str("text"),
Self::Chat => formatter.write_str("chat"),
Self::Responses => formatter.write_str("responses"),
}
}
}
impl std::fmt::Display for TelegramMode {
fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::Polling => formatter.write_str("polling"),
Self::Webhook => formatter.write_str("webhook"),
}
}
}
fn main() -> Result<(), Box<dyn Error>> {
lino_arguments::init();
let args = Args::parse();
let command = args.command.unwrap_or_else(|| Command::Chat {
prompt: String::from("Hi"),
format: OutputFormat::Text,
});
match command {
Command::Chat { prompt, format } => run_chat(&prompt, format)?,
Command::Dataset => println!("{}", knowledge_links_notation()),
Command::Serve { host, port } => run_telegram_webhook_server(&format!("{host}:{port}"))?,
Command::Telegram {
mode,
token,
api_base,
timeout,
limit,
allowed_updates,
host,
port,
} => run_telegram(TelegramRunArgs {
mode,
token,
api_base,
timeout,
limit,
allowed_updates,
host,
port,
})?,
}
Ok(())
}
struct TelegramRunArgs {
mode: TelegramMode,
token: Option<String>,
api_base: String,
timeout: u32,
limit: u32,
allowed_updates: String,
host: String,
port: u16,
}
fn run_chat(prompt: &str, format: OutputFormat) -> Result<(), Box<dyn Error>> {
match format {
OutputFormat::Text => {
let response = FormalAiEngine.answer(prompt);
println!("{}", response.answer);
}
OutputFormat::Chat => {
let request = ChatCompletionRequest {
model: Some(String::from(DEFAULT_MODEL)),
messages: vec![ChatMessage {
role: String::from("user"),
content: MessageContent::Text(String::from(prompt)),
}],
stream: false,
};
println!(
"{}",
serde_json::to_string_pretty(&create_chat_completion(&request))?
);
}
OutputFormat::Responses => {
let request = ResponsesRequest {
model: Some(String::from(DEFAULT_MODEL)),
input: serde_json::Value::String(String::from(prompt)),
instructions: None,
stream: false,
};
println!(
"{}",
serde_json::to_string_pretty(&create_response(&request))?
);
}
}
Ok(())
}
fn run_telegram(args: TelegramRunArgs) -> Result<(), Box<dyn Error>> {
match args.mode {
TelegramMode::Polling => {
let token = args.token.ok_or_else(|| {
String::from(
"Telegram polling mode requires a bot token. \
Pass --token or set TELEGRAM_BOT_TOKEN.",
)
})?;
let mut config = TelegramPollingConfig::new(token);
config.api_base = args.api_base;
config.timeout_seconds = args.timeout;
config.limit = args.limit.clamp(1, 100);
config.allowed_updates = parse_allowed_updates(&args.allowed_updates);
run_telegram_polling(&config, None, Arc::new(AtomicBool::new(false)))?;
}
TelegramMode::Webhook => {
run_telegram_webhook_server(&format!(
"{host}:{port}",
host = args.host,
port = args.port
))?;
}
}
Ok(())
}
fn parse_allowed_updates(raw: &str) -> Vec<String> {
raw.split(',')
.map(str::trim)
.filter(|entry| !entry.is_empty())
.map(ToOwned::to_owned)
.collect()
}