use crate::device::{get_device, DeviceConfig};
use crate::ecosystem::EcosystemInfo;
use crate::error::{CoreError, Result};
use candle_core::Device;
/// Top-level configuration used to construct a [`RustAI`] instance.
///
/// Build one via [`RustAIConfig::new`] / [`Default`] and customize it with
/// the `with_*` builder methods.
#[derive(Debug, Clone)]
pub struct RustAIConfig {
    /// Device selection settings (populated from the environment by default).
    pub device: DeviceConfig,
    /// When `true`, initialization emits `tracing` info logs.
    pub verbose: bool,
    /// Memory budget in bytes; defaults to 0.
    /// NOTE(review): 0 presumably means "no limit" — confirm downstream use.
    pub memory_limit: usize,
}
impl Default for RustAIConfig {
fn default() -> Self {
Self {
device: DeviceConfig::from_env(),
verbose: false,
memory_limit: 0,
}
}
}
impl RustAIConfig {
    /// Creates a configuration with default settings (equivalent to
    /// [`RustAIConfig::default`]).
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets whether initialization should emit `tracing` info logs.
    #[must_use]
    pub fn with_verbose(self, verbose: bool) -> Self {
        Self { verbose, ..self }
    }

    /// Sets the memory budget in bytes.
    #[must_use]
    pub fn with_memory_limit(self, limit: usize) -> Self {
        Self {
            memory_limit: limit,
            ..self
        }
    }

    /// Forces all computation onto the CPU.
    #[must_use]
    pub fn with_cpu(mut self) -> Self {
        self.device = self.device.with_force_cpu(true);
        self
    }

    /// Selects a specific CUDA device by ordinal.
    #[must_use]
    pub fn with_cuda_device(mut self, ordinal: usize) -> Self {
        self.device = self.device.with_cuda_device(ordinal);
        self
    }
}
/// Top-level API handle tying together the resolved compute device and
/// ecosystem metadata; construct with [`RustAI::new`].
pub struct RustAI {
    // Configuration this instance was built from.
    config: RustAIConfig,
    // Compute device resolved from `config.device` at construction time.
    device: Device,
    // Ecosystem metadata captured at construction time.
    ecosystem: EcosystemInfo,
}
impl RustAI {
    /// Initializes the API handle from `config`, resolving the compute
    /// device and capturing ecosystem metadata. When `config.verbose` is
    /// set, logs the resulting device and crate list via `tracing`.
    ///
    /// # Errors
    ///
    /// Returns an error when the configured device cannot be acquired.
    pub fn new(config: RustAIConfig) -> Result<Self> {
        let device = get_device(&config.device)?;
        let ecosystem = EcosystemInfo::new();
        if config.verbose {
            tracing::info!("RustAI initialized");
            tracing::info!("Device: {:?}", device);
            tracing::info!("Ecosystem crates: {:?}", EcosystemInfo::crate_names());
        }
        Ok(Self {
            config,
            device,
            ecosystem,
        })
    }

    /// The resolved compute device.
    #[must_use]
    pub fn device(&self) -> &Device {
        &self.device
    }

    /// Metadata about the ecosystem crates.
    #[must_use]
    pub fn ecosystem(&self) -> &EcosystemInfo {
        &self.ecosystem
    }

    /// The configuration this instance was created with.
    #[must_use]
    pub fn config(&self) -> &RustAIConfig {
        &self.config
    }

    /// Returns `true` when the active device is a CUDA GPU.
    #[must_use]
    pub fn is_cuda(&self) -> bool {
        matches!(&self.device, Device::Cuda(_))
    }

    /// Starts configuring a fine-tuning run.
    #[must_use]
    pub fn finetune(&self) -> FinetuneBuilder<'_> {
        FinetuneBuilder::new(self)
    }

    /// Starts configuring a quantization pass.
    #[must_use]
    pub fn quantize(&self) -> QuantizeBuilder<'_> {
        QuantizeBuilder::new(self)
    }

    /// Starts configuring a VSA (vector-symbolic architecture) workspace.
    #[must_use]
    pub fn vsa(&self) -> VsaBuilder<'_> {
        VsaBuilder::new(self)
    }

    /// Starts configuring a training run.
    #[must_use]
    pub fn train(&self) -> TrainBuilder<'_> {
        TrainBuilder::new(self)
    }

    /// Produces a human-readable snapshot of version, device, and
    /// ecosystem details.
    #[must_use]
    pub fn info(&self) -> RustAIInfo {
        let ecosystem_crates = EcosystemInfo::crate_names()
            .iter()
            .map(|name| String::from(*name))
            .collect();
        RustAIInfo {
            version: crate::VERSION.to_string(),
            device: format!("{:?}", self.device),
            ecosystem_crates,
            cuda_available: self.is_cuda(),
            memory_limit: self.config.memory_limit,
        }
    }
}
/// Snapshot of a [`RustAI`] instance's state, as produced by
/// [`RustAI::info`].
#[derive(Debug, Clone)]
pub struct RustAIInfo {
    /// Crate version string (taken from `crate::VERSION`).
    pub version: String,
    /// `Debug` rendering of the active compute device.
    pub device: String,
    /// Names of the ecosystem crates.
    pub ecosystem_crates: Vec<String>,
    /// Whether the active device is CUDA.
    pub cuda_available: bool,
    /// Configured memory limit in bytes (copied from the config).
    pub memory_limit: usize,
}
/// Fluent builder for a [`FinetuneConfig`]; obtained from
/// [`RustAI::finetune`].
pub struct FinetuneBuilder<'a> {
    // Back-reference to the API handle; currently unused by the builder.
    #[allow(dead_code)]
    ai: &'a RustAI,
    // Base model path; required — `build` fails when still `None`.
    model_path: Option<String>,
    // Adapter family to attach (LoRA by default).
    adapter_type: AdapterType,
    // Low-rank decomposition rank.
    rank: usize,
    // Adapter scaling factor.
    alpha: f32,
    // Adapter dropout probability.
    dropout: f32,
    // Module names the adapter is applied to.
    target_modules: Vec<String>,
}
/// Parameter-efficient fine-tuning adapter families.
///
/// Derives `PartialEq`/`Eq`/`Hash` so call sites can compare variants
/// directly (e.g. with `assert_eq!`) instead of resorting to `matches!`,
/// and can use the type as a map key. Fully backward-compatible.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub enum AdapterType {
    /// Low-Rank Adaptation (the default).
    #[default]
    Lora,
    /// DoRA adapter variant.
    Dora,
    /// AdaLoRA adapter variant.
    AdaLora,
}
impl<'a> FinetuneBuilder<'a> {
    /// Seeds a builder with LoRA-style defaults: rank 64, alpha 16.0,
    /// dropout 0.1, targeting the `q_proj` and `v_proj` modules.
    fn new(ai: &'a RustAI) -> Self {
        Self {
            ai,
            model_path: None,
            adapter_type: AdapterType::default(),
            rank: 64,
            alpha: 16.0,
            dropout: 0.1,
            target_modules: vec![String::from("q_proj"), String::from("v_proj")],
        }
    }

    /// Sets the path (or identifier) of the base model to fine-tune.
    #[must_use]
    pub fn model(self, path: impl Into<String>) -> Self {
        Self {
            model_path: Some(path.into()),
            ..self
        }
    }

    /// Chooses the adapter family to attach.
    #[must_use]
    pub fn adapter(self, adapter: AdapterType) -> Self {
        Self {
            adapter_type: adapter,
            ..self
        }
    }

    /// Sets the low-rank decomposition rank.
    #[must_use]
    pub fn rank(self, rank: usize) -> Self {
        Self { rank, ..self }
    }

    /// Sets the adapter scaling factor.
    #[must_use]
    pub fn alpha(self, alpha: f32) -> Self {
        Self { alpha, ..self }
    }

    /// Sets the adapter dropout probability.
    #[must_use]
    pub fn dropout(self, dropout: f32) -> Self {
        Self { dropout, ..self }
    }

    /// Replaces the list of module names the adapter targets.
    #[must_use]
    pub fn target_modules(self, modules: Vec<String>) -> Self {
        Self {
            target_modules: modules,
            ..self
        }
    }

    /// Finalizes the fine-tuning configuration.
    ///
    /// # Errors
    ///
    /// Fails when no model path was supplied via [`Self::model`].
    pub fn build(self) -> Result<FinetuneConfig> {
        match self.model_path {
            Some(model_path) => Ok(FinetuneConfig {
                model_path,
                adapter_type: self.adapter_type,
                rank: self.rank,
                alpha: self.alpha,
                dropout: self.dropout,
                target_modules: self.target_modules,
            }),
            None => Err(CoreError::invalid_config(
                "model path is required for fine-tuning",
            )),
        }
    }
}
/// Finalized fine-tuning settings produced by [`FinetuneBuilder::build`].
#[derive(Debug, Clone)]
pub struct FinetuneConfig {
    /// Path (or identifier) of the base model.
    pub model_path: String,
    /// Adapter family to attach.
    pub adapter_type: AdapterType,
    /// Low-rank decomposition rank.
    pub rank: usize,
    /// Adapter scaling factor.
    pub alpha: f32,
    /// Adapter dropout probability.
    pub dropout: f32,
    /// Module names the adapter is applied to.
    pub target_modules: Vec<String>,
}
/// Fluent builder for a [`QuantizeConfig`]; obtained from
/// [`RustAI::quantize`].
pub struct QuantizeBuilder<'a> {
    // Back-reference to the API handle; currently unused by the builder.
    #[allow(dead_code)]
    ai: &'a RustAI,
    // Quantization method (NF4 by default).
    method: QuantizeMethod,
    // Bit width of the quantized representation.
    bits: u8,
    // Number of weights sharing quantization parameters.
    group_size: usize,
}
/// Supported quantization methods.
///
/// Derives `PartialEq`/`Eq`/`Hash` so call sites can compare variants
/// directly instead of the `matches!` workaround the test suite currently
/// uses, and can use the type as a map key. Fully backward-compatible.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub enum QuantizeMethod {
    /// 4-bit NormalFloat (the default).
    #[default]
    Nf4,
    /// 4-bit float.
    Fp4,
    /// BitNet-style quantization.
    BitNet,
    /// 8-bit integer quantization.
    Int8,
}
impl<'a> QuantizeBuilder<'a> {
    /// Seeds a builder with NF4 defaults: 4 bits, group size 64.
    fn new(ai: &'a RustAI) -> Self {
        Self {
            ai,
            method: QuantizeMethod::default(),
            bits: 4,
            group_size: 64,
        }
    }

    /// Chooses the quantization method.
    #[must_use]
    pub fn method(self, method: QuantizeMethod) -> Self {
        Self { method, ..self }
    }

    /// Sets the bit width of the quantized representation.
    #[must_use]
    pub fn bits(self, bits: u8) -> Self {
        Self { bits, ..self }
    }

    /// Sets the number of weights that share quantization parameters.
    #[must_use]
    pub fn group_size(self, size: usize) -> Self {
        Self {
            group_size: size,
            ..self
        }
    }

    /// Finalizes the quantization configuration.
    #[must_use]
    pub fn build(self) -> QuantizeConfig {
        let Self {
            method,
            bits,
            group_size,
            ..
        } = self;
        QuantizeConfig {
            method,
            bits,
            group_size,
        }
    }
}
/// Finalized quantization settings produced by [`QuantizeBuilder::build`].
#[derive(Debug, Clone)]
pub struct QuantizeConfig {
    /// Quantization method.
    pub method: QuantizeMethod,
    /// Bit width of the quantized representation.
    pub bits: u8,
    /// Number of weights sharing quantization parameters.
    pub group_size: usize,
}
/// Fluent builder for a [`VsaConfig`]; obtained from [`RustAI::vsa`].
pub struct VsaBuilder<'a> {
    // Back-reference to the API handle; currently unused by the builder.
    #[allow(dead_code)]
    ai: &'a RustAI,
    // Hypervector dimensionality (10000 by default).
    dimension: usize,
}
impl<'a> VsaBuilder<'a> {
    /// Seeds a builder with a 10,000-dimensional vector space.
    fn new(ai: &'a RustAI) -> Self {
        Self {
            ai,
            dimension: 10_000,
        }
    }

    /// Sets the hypervector dimensionality.
    #[must_use]
    pub fn dimension(self, dim: usize) -> Self {
        Self {
            dimension: dim,
            ..self
        }
    }

    /// Finalizes the VSA configuration.
    #[must_use]
    pub fn build(self) -> VsaConfig {
        VsaConfig {
            dimension: self.dimension,
        }
    }
}
/// Finalized VSA settings produced by [`VsaBuilder::build`].
#[derive(Debug, Clone)]
pub struct VsaConfig {
    /// Hypervector dimensionality.
    pub dimension: usize,
}
/// Fluent builder for a [`TrainConfig`]; obtained from [`RustAI::train`].
pub struct TrainBuilder<'a> {
    // Back-reference to the API handle; currently unused by the builder.
    #[allow(dead_code)]
    ai: &'a RustAI,
    // Training config file path; required — `build` fails when `None`.
    config_path: Option<String>,
}
impl<'a> TrainBuilder<'a> {
    /// Creates a builder with no config file selected yet.
    fn new(ai: &'a RustAI) -> Self {
        Self {
            ai,
            config_path: None,
        }
    }

    /// Sets the path of the training configuration file.
    #[must_use]
    pub fn config_file(self, path: impl Into<String>) -> Self {
        Self {
            config_path: Some(path.into()),
            ..self
        }
    }

    /// Finalizes the training configuration.
    ///
    /// # Errors
    ///
    /// Fails when no config file was supplied via [`Self::config_file`].
    pub fn build(self) -> Result<TrainConfig> {
        self.config_path
            .map(|config_path| TrainConfig { config_path })
            .ok_or_else(|| CoreError::invalid_config("config file path is required"))
    }
}
/// Finalized training settings produced by [`TrainBuilder::build`].
#[derive(Debug, Clone)]
pub struct TrainConfig {
    /// Path of the training configuration file.
    pub config_path: String,
}
#[cfg(test)]
mod tests {
    use super::*;

    // Default config: quiet, no memory limit.
    #[test]
    fn test_rustai_config_default() {
        let config = RustAIConfig::default();
        assert!(!config.verbose);
        assert_eq!(config.memory_limit, 0);
    }

    // The with_* builder methods compose and each one sticks.
    #[test]
    fn test_rustai_config_builder() {
        let config = RustAIConfig::new()
            .with_verbose(true)
            .with_memory_limit(1024 * 1024 * 1024)
            .with_cpu();
        assert!(config.verbose);
        assert_eq!(config.memory_limit, 1024 * 1024 * 1024);
        assert!(config.device.force_cpu);
    }

    // CPU-forced construction succeeds and reports a non-CUDA device.
    // The ecosystem crate count (8) mirrors the EcosystemInfo module.
    #[test]
    fn test_rustai_new() {
        let config = RustAIConfig::new().with_cpu();
        let ai = RustAI::new(config).unwrap();
        assert!(!ai.is_cuda());
        assert_eq!(EcosystemInfo::crate_names().len(), 8);
    }

    // info() snapshot is populated and consistent with the CPU device.
    #[test]
    fn test_rustai_info() {
        let config = RustAIConfig::new().with_cpu();
        let ai = RustAI::new(config).unwrap();
        let info = ai.info();
        assert!(!info.version.is_empty());
        assert!(!info.cuda_available);
        assert_eq!(info.ecosystem_crates.len(), 8);
    }

    // Fine-tune builder: required model path plus overridden rank/alpha.
    #[test]
    fn test_finetune_builder() {
        let config = RustAIConfig::new().with_cpu();
        let ai = RustAI::new(config).unwrap();
        let finetune_config = ai
            .finetune()
            .model("test-model")
            .rank(32)
            .alpha(8.0)
            .build()
            .unwrap();
        assert_eq!(finetune_config.model_path, "test-model");
        assert_eq!(finetune_config.rank, 32);
        // Epsilon comparison because alpha is f32.
        assert!((finetune_config.alpha - 8.0).abs() < f32::EPSILON);
    }

    // Quantize builder: every setter overrides its default.
    #[test]
    fn test_quantize_builder() {
        let config = RustAIConfig::new().with_cpu();
        let ai = RustAI::new(config).unwrap();
        let quant_config = ai
            .quantize()
            .method(QuantizeMethod::BitNet)
            .bits(2)
            .group_size(128)
            .build();
        assert!(matches!(quant_config.method, QuantizeMethod::BitNet));
        assert_eq!(quant_config.bits, 2);
        assert_eq!(quant_config.group_size, 128);
    }

    // VSA builder: dimension override is preserved.
    #[test]
    fn test_vsa_builder() {
        let config = RustAIConfig::new().with_cpu();
        let ai = RustAI::new(config).unwrap();
        let vsa_config = ai.vsa().dimension(8192).build();
        assert_eq!(vsa_config.dimension, 8192);
    }

    // Train builder: required config file path is carried through.
    #[test]
    fn test_train_builder() {
        let config = RustAIConfig::new().with_cpu();
        let ai = RustAI::new(config).unwrap();
        let train_config = ai.train().config_file("train.yaml").build().unwrap();
        assert_eq!(train_config.config_path, "train.yaml");
    }
}