use rhai::{Array, Engine, EvalAltResult};
use std::path::Path;
use crate::client::AiClient;
use crate::embedding::EmbeddingModel;
use crate::model::Model;
use crate::transcription::{TranscriptionModel, TranscriptionOptions};
use crate::types::Message;
/// Wrap any displayable error into a boxed rhai runtime error,
/// prefixing it with a short context message ("<msg>: <error>").
fn to_rhai_error<E: std::fmt::Display>(msg: &str, e: E) -> Box<EvalAltResult> {
    let text = format!("{}: {}", msg, e);
    Box::new(EvalAltResult::ErrorRuntime(text.into(), rhai::Position::NONE))
}
/// Map a user-supplied chat-model name to a [`Model`] variant.
///
/// Matching is case-insensitive and treats `-` as `_`; dotted spellings
/// (e.g. `llama3.3_70b`) and short aliases (e.g. `deepseek`, `mixtral`)
/// are accepted. On an unknown name, returns a rhai runtime error that
/// lists every accepted canonical spelling.
///
/// Fix: the previous error message omitted four accepted models
/// (deepseek_coder_v2_5, llama3_1_405b, llama3_2_90b_vision,
/// llama3_2_11b_vision); the list now matches the match arms above.
fn parse_model(model: &str) -> Result<Model, Box<EvalAltResult>> {
    match model.to_lowercase().replace("-", "_").as_str() {
        "llama3_3_70b" | "llama3.3_70b" => Ok(Model::Llama3_3_70B),
        "llama3_1_70b" | "llama3.1_70b" => Ok(Model::Llama3_1_70B),
        "llama3_1_8b" | "llama3.1_8b" => Ok(Model::Llama3_1_8B),
        "qwen2_5_coder_32b" | "qwen2.5_coder_32b" => Ok(Model::Qwen2_5Coder32B),
        "deepseek_coder_v2_5" | "deepseek_coder" => Ok(Model::DeepSeekCoderV2_5),
        "deepseek_v3" | "deepseek" => Ok(Model::DeepSeekV3),
        "llama3_1_405b" | "llama3.1_405b" => Ok(Model::Llama3_1_405B),
        "mixtral_8x7b" | "mixtral" => Ok(Model::Mixtral8x7B),
        "llama3_2_90b_vision" | "llama3.2_90b_vision" => Ok(Model::Llama3_2_90BVision),
        "llama3_2_11b_vision" | "llama3.2_11b_vision" => Ok(Model::Llama3_2_11BVision),
        "nemotron_nano_30b" | "nemotron" => Ok(Model::NemotronNano30B),
        "gpt_oss_120b" | "gpt_oss" | "gptoss" => Ok(Model::GptOss120B),
        _ => Err(Box::new(EvalAltResult::ErrorRuntime(
            format!(
                "Unknown model: {}. Available: llama3_3_70b, llama3_1_70b, llama3_1_8b, \
                 qwen2_5_coder_32b, deepseek_coder_v2_5, deepseek_v3, llama3_1_405b, \
                 mixtral_8x7b, llama3_2_90b_vision, llama3_2_11b_vision, \
                 nemotron_nano_30b, gpt_oss_120b",
                model
            )
            .into(),
            rhai::Position::NONE,
        ))),
    }
}
/// Map a user-supplied embedding-model name to an [`EmbeddingModel`] variant.
///
/// Matching is case-insensitive and treats `-` as `_`. Unknown names
/// produce a rhai runtime error listing the accepted spellings.
fn parse_embedding_model(model: &str) -> Result<EmbeddingModel, Box<EvalAltResult>> {
    let normalized = model.to_lowercase().replace("-", "_");
    match normalized.as_str() {
        "text_embedding_3_small" | "openai_small" => Ok(EmbeddingModel::TextEmbedding3Small),
        "qwen3_embedding_8b" | "qwen_embedding" => Ok(EmbeddingModel::Qwen3Embedding8B),
        _ => Err(Box::new(EvalAltResult::ErrorRuntime(
            format!(
                "Unknown embedding model: {}. Available: text_embedding_3_small, qwen3_embedding_8b",
                model
            )
            .into(),
            rhai::Position::NONE,
        ))),
    }
}
/// Map a user-supplied transcription-model name to a [`TranscriptionModel`].
///
/// Matching is case-insensitive and treats `-` as `_`; `whisper_turbo`
/// and `whisper` are short aliases. Unknown names produce a rhai runtime
/// error listing the accepted spellings.
fn parse_transcription_model(model: &str) -> Result<TranscriptionModel, Box<EvalAltResult>> {
    let normalized = model.to_lowercase().replace("-", "_");
    match normalized.as_str() {
        "whisper_large_v3_turbo" | "whisper_turbo" => Ok(TranscriptionModel::WhisperLargeV3Turbo),
        "whisper_large_v3" | "whisper" => Ok(TranscriptionModel::WhisperLargeV3),
        _ => Err(Box::new(EvalAltResult::ErrorRuntime(
            format!(
                "Unknown transcription model: {}. Available: whisper_large_v3_turbo, whisper_large_v3",
                model
            )
            .into(),
            rhai::Position::NONE,
        ))),
    }
}
/// Build an [`AiClient`] from environment variables.
///
/// Fails with a rhai runtime error when no provider API key is set, so
/// scripts get an actionable message instead of a provider-side failure.
fn get_client() -> Result<AiClient, Box<EvalAltResult>> {
    let client = AiClient::from_env();
    if client.has_providers() {
        Ok(client)
    } else {
        Err(Box::new(EvalAltResult::ErrorRuntime(
            "No AI providers configured. Set GROQ_API_KEY, OPENROUTER_API_KEY, or SAMBANOVA_API_KEY environment variable.".into(),
            rhai::Position::NONE,
        )))
    }
}
/// `ai_chat(prompt)`: send one user prompt to the default general chat
/// model and return the assistant's reply text.
fn rhai_ai_chat(prompt: &str) -> Result<String, Box<EvalAltResult>> {
    let response = get_client()?
        .chat(Model::default_general(), vec![Message::user(prompt)])
        .map_err(|e| to_rhai_error("Chat failed", e))?;
    match response.content() {
        Some(text) => Ok(text.to_string()),
        None => Err(to_rhai_error("No content in response", "empty response")),
    }
}
/// `ai_chat_with_model(model, prompt)`: like `ai_chat`, but the chat
/// model is chosen by name (see [`parse_model`] for accepted spellings).
fn rhai_ai_chat_with_model(model: &str, prompt: &str) -> Result<String, Box<EvalAltResult>> {
    let client = get_client()?;
    let chosen = parse_model(model)?;
    let response = client
        .chat(chosen, vec![Message::user(prompt)])
        .map_err(|e| to_rhai_error("Chat failed", e))?;
    match response.content() {
        Some(text) => Ok(text.to_string()),
        None => Err(to_rhai_error("No content in response", "empty response")),
    }
}
/// `ai_chat_with_system(model, system, prompt)`: chat with an explicit
/// system message preceding the user prompt; the model is chosen by name.
fn rhai_ai_chat_with_system(
    model: &str,
    system: &str,
    prompt: &str,
) -> Result<String, Box<EvalAltResult>> {
    let client = get_client()?;
    let chosen = parse_model(model)?;
    // System message first so it frames the user turn.
    let conversation = vec![Message::system(system), Message::user(prompt)];
    let response = client
        .chat(chosen, conversation)
        .map_err(|e| to_rhai_error("Chat failed", e))?;
    match response.content() {
        Some(text) => Ok(text.to_string()),
        None => Err(to_rhai_error("No content in response", "empty response")),
    }
}
fn rhai_ai_embed(text: &str) -> Result<Array, Box<EvalAltResult>> {
let client = get_client()?;
let response = client
.embed(EmbeddingModel::default(), text)
.map_err(|e| to_rhai_error("Embedding failed", e))?;
let embedding = response
.embedding()
.ok_or_else(|| to_rhai_error("No embedding in response", "empty response"))?;
Ok(embedding
.iter()
.map(|&f| rhai::Dynamic::from(f as f64))
.collect())
}
fn rhai_ai_embed_with_model(model: &str, text: &str) -> Result<Array, Box<EvalAltResult>> {
let client = get_client()?;
let model = parse_embedding_model(model)?;
let response = client
.embed(model, text)
.map_err(|e| to_rhai_error("Embedding failed", e))?;
let embedding = response
.embedding()
.ok_or_else(|| to_rhai_error("No embedding in response", "empty response"))?;
Ok(embedding
.iter()
.map(|&f| rhai::Dynamic::from(f as f64))
.collect())
}
fn rhai_ai_embed_batch(texts: Array) -> Result<Array, Box<EvalAltResult>> {
let client = get_client()?;
let texts: Vec<String> = texts
.into_iter()
.map(|v| {
v.into_string()
.map_err(|_| to_rhai_error("Invalid text", "expected string"))
})
.collect::<Result<Vec<_>, _>>()?;
let response = client
.embed_batch(EmbeddingModel::default(), texts)
.map_err(|e| to_rhai_error("Batch embedding failed", e))?;
let embeddings: Array = response
.embeddings()
.iter()
.map(|emb| {
rhai::Dynamic::from(
emb.iter()
.map(|&f| rhai::Dynamic::from(f as f64))
.collect::<Array>(),
)
})
.collect();
Ok(embeddings)
}
/// Check that `file_path` names an existing file and return it as a
/// [`Path`], producing a script-friendly rhai error (rather than a
/// provider-side failure) when it does not exist.
///
/// Extracted helper: this check was previously duplicated verbatim in
/// all three transcription entry points below.
fn ensure_audio_file(file_path: &str) -> Result<&Path, Box<EvalAltResult>> {
    let path = Path::new(file_path);
    if path.exists() {
        Ok(path)
    } else {
        Err(Box::new(EvalAltResult::ErrorRuntime(
            format!("Audio file not found: {}", file_path).into(),
            rhai::Position::NONE,
        )))
    }
}

/// `ai_transcribe(file_path)`: transcribe an audio file with the default
/// transcription model and return the recognized text.
fn rhai_ai_transcribe(file_path: &str) -> Result<String, Box<EvalAltResult>> {
    let client = get_client()?;
    let path = ensure_audio_file(file_path)?;
    let response = client
        .transcribe_file(TranscriptionModel::default(), path)
        .map_err(|e| to_rhai_error("Transcription failed", e))?;
    Ok(response.text)
}

/// `ai_transcribe_with_model(model, file_path)`: like `ai_transcribe`,
/// but the transcription model is chosen by name
/// (see [`parse_transcription_model`]).
fn rhai_ai_transcribe_with_model(
    model: &str,
    file_path: &str,
) -> Result<String, Box<EvalAltResult>> {
    let client = get_client()?;
    let model = parse_transcription_model(model)?;
    let path = ensure_audio_file(file_path)?;
    let response = client
        .transcribe_file(model, path)
        .map_err(|e| to_rhai_error("Transcription failed", e))?;
    Ok(response.text)
}

/// `ai_transcribe_with_options(model, file_path, language)`: transcribe
/// with an explicit model and a language hint passed through
/// [`TranscriptionOptions`].
fn rhai_ai_transcribe_with_options(
    model: &str,
    file_path: &str,
    language: &str,
) -> Result<String, Box<EvalAltResult>> {
    let client = get_client()?;
    let model = parse_transcription_model(model)?;
    let path = ensure_audio_file(file_path)?;
    let options = TranscriptionOptions::new().with_language(language);
    let response = client
        .transcribe_file_with_options(model, path, options)
        .map_err(|e| to_rhai_error("Transcription failed", e))?;
    Ok(response.text)
}
fn rhai_ai_models() -> Array {
Model::all()
.iter()
.map(|m| rhai::Dynamic::from(m.name().to_string()))
.collect()
}
fn rhai_ai_embedding_models() -> Array {
EmbeddingModel::all()
.iter()
.map(|m| rhai::Dynamic::from(m.name().to_string()))
.collect()
}
fn rhai_ai_transcription_models() -> Array {
TranscriptionModel::all()
.iter()
.map(|m| rhai::Dynamic::from(m.name().to_string()))
.collect()
}
/// Register every AI scripting function (`ai_chat*`, `ai_embed*`,
/// `ai_transcribe*`, `ai_*_models`) on the given rhai engine.
///
/// Always succeeds; the `Result` return is kept for API symmetry with
/// other registration hooks.
pub fn register(engine: &mut Engine) -> Result<(), Box<EvalAltResult>> {
    // `register_fn` returns `&mut Engine`, so registrations chain.
    engine
        .register_fn("ai_chat", rhai_ai_chat)
        .register_fn("ai_chat_with_model", rhai_ai_chat_with_model)
        .register_fn("ai_chat_with_system", rhai_ai_chat_with_system)
        .register_fn("ai_embed", rhai_ai_embed)
        .register_fn("ai_embed_with_model", rhai_ai_embed_with_model)
        .register_fn("ai_embed_batch", rhai_ai_embed_batch)
        .register_fn("ai_transcribe", rhai_ai_transcribe)
        .register_fn("ai_transcribe_with_model", rhai_ai_transcribe_with_model)
        .register_fn("ai_transcribe_with_options", rhai_ai_transcribe_with_options)
        .register_fn("ai_models", rhai_ai_models)
        .register_fn("ai_embedding_models", rhai_ai_embedding_models)
        .register_fn("ai_transcription_models", rhai_ai_transcription_models);
    Ok(())
}
/// Backward-compatible alias for [`register`]; kept so existing callers
/// using the older name continue to compile.
pub fn register_ai_module(engine: &mut Engine) -> Result<(), Box<EvalAltResult>> {
register(engine)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a fresh engine with every AI function registered.
    fn ai_engine() -> Engine {
        let mut engine = Engine::new();
        register(&mut engine).unwrap();
        engine
    }

    #[test]
    fn test_parse_model() {
        assert!(matches!(parse_model("llama3_3_70b"), Ok(_)));
        assert!(matches!(parse_model("deepseek_v3"), Ok(_)));
        assert!(matches!(parse_model("invalid_model"), Err(_)));
    }

    #[test]
    fn test_parse_embedding_model() {
        assert!(matches!(parse_embedding_model("text_embedding_3_small"), Ok(_)));
        assert!(matches!(parse_embedding_model("qwen3_embedding_8b"), Ok(_)));
        assert!(matches!(parse_embedding_model("invalid"), Err(_)));
    }

    #[test]
    fn test_parse_transcription_model() {
        assert!(matches!(parse_transcription_model("whisper_large_v3_turbo"), Ok(_)));
        assert!(matches!(parse_transcription_model("whisper_large_v3"), Ok(_)));
        assert!(matches!(parse_transcription_model("invalid"), Err(_)));
    }

    #[test]
    fn test_register() {
        let mut engine = Engine::new();
        assert!(register(&mut engine).is_ok());
    }

    #[test]
    fn test_ai_models_function() {
        let result = ai_engine().eval::<Array>("ai_models()").unwrap();
        assert!(!result.is_empty());
    }

    #[test]
    fn test_ai_embedding_models_function() {
        let result = ai_engine().eval::<Array>("ai_embedding_models()").unwrap();
        assert!(!result.is_empty());
    }

    #[test]
    fn test_ai_transcription_models_function() {
        let result = ai_engine().eval::<Array>("ai_transcription_models()").unwrap();
        assert!(!result.is_empty());
    }
}