use serde::{Deserialize, Serialize};
use std::path::{Path, PathBuf};
use thiserror::Error;
/// Token accounting attached to an `EmbeddingResponse` (its `usage` field).
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct Usage {
    /// Token count for the request input — presumably the tokenized prompt(s); confirm with the producer.
    pub prompt_tokens: usize,
    /// Total token count for the request; its relation to `prompt_tokens` is not defined here — TODO confirm.
    pub total_tokens: usize,
}
/// A source-code input with optional precomputed structural features.
///
/// NOTE(review): the field names suggest GraphCodeBERT-style inputs (custom position ids plus a
/// data-flow-graph mask); confirm against the model code that consumes this type.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GraphCodeInput {
    /// The source code text.
    pub code: String,
    /// Optional per-token position ids; what `None` implies is decided by the consumer — TODO confirm.
    pub position_ids: Option<Vec<i64>>,
    /// Optional 2-D integer mask (a list of rows); presumably a token-to-token data-flow mask — TODO confirm.
    pub dfg_mask: Option<Vec<Vec<i64>>>,
}
/// One embedding vector, as carried in `EmbeddingResponse::embeddings`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Embedding {
    /// Position of this entry — presumably the index of the corresponding input in the request; confirm.
    pub index: usize,
    /// The embedding values.
    pub embedding: Vec<f32>,
}
/// Result of an embedding request: the vectors plus token usage.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbeddingResponse {
    /// One entry per embedded item; each carries its own `index`.
    pub embeddings: Vec<Embedding>,
    /// Token accounting for the request.
    pub usage: Usage,
}
/// A single scored document, as carried in `RerankResponse::results`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RerankResult {
    /// Presumably the index of the document in the request's input list — TODO confirm.
    pub index: usize,
    /// Score assigned to this document (scale and ordering convention are not defined here).
    pub score: f32,
    /// The document text, when the producer includes it; otherwise `None`.
    pub document: Option<String>,
}
/// Result of a rerank request.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RerankResponse {
    /// Scored documents; ordering is whatever the producer emits — NOTE(review): confirm whether sorted by score.
    pub results: Vec<RerankResult>,
}
/// Compute device selection. Defaults to [`Device::Auto`].
#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize)]
pub enum Device {
    /// Automatic device selection (the selection policy is not defined in this file).
    #[default]
    Auto,
    /// A specific GPU; the `usize` is presumably the device ordinal — TODO confirm.
    Gpu(usize),
    /// The CPU.
    Cpu,
}
/// Client configuration: the model directory on disk and the device to use.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ClientConfig {
    /// Directory for model files; `Default` sets it to `<home>/.gllm/models`.
    pub models_dir: PathBuf,
    /// Device selection; `Default` sets it to [`Device::Auto`].
    pub device: Device,
}
impl Default for ClientConfig {
fn default() -> Self {
let models_dir = dirs::home_dir()
.unwrap_or_else(|| PathBuf::from("."))
.join(".gllm")
.join("models");
Self {
models_dir,
device: Device::Auto,
}
}
}
impl ClientConfig {
    /// Borrows the configured model directory (`models_dir`) as a [`Path`].
    pub fn model_dir(&self) -> &Path {
        self.models_dir.as_path()
    }
}
#[derive(Debug, Error)]
pub enum Error {
#[error("Model not found: {0}")]
ModelNotFound(String),
#[error("Failed to download model: {0}")]
DownloadError(String),
#[error("Failed to load model: {0}")]
LoadError(String),
#[error("Inference error: {0}")]
InferenceError(String),
#[error("Out of memory: {0}")]
Oom(String),
#[error("Invalid configuration: {0}")]
InvalidConfig(String),
#[error("IO error: {0}")]
Io(std::io::Error),
#[error("Internal error: {0}")]
InternalError(String),
}
impl Error {
pub fn is_oom(&self) -> bool {
match self {
Error::Oom(_) => true,
Error::InferenceError(msg) | Error::InternalError(msg) => {
let msg_lower = msg.to_lowercase();
msg_lower.contains("oom")
|| msg_lower.contains("out of memory")
|| msg_lower.contains("allocate")
|| msg_lower.contains("memory")
|| msg_lower.contains("buffer")
}
_ => false,
}
}
}
impl Clone for Error {
fn clone(&self) -> Self {
match self {
Error::ModelNotFound(s) => Error::ModelNotFound(s.clone()),
Error::DownloadError(s) => Error::DownloadError(s.clone()),
Error::LoadError(s) => Error::LoadError(s.clone()),
Error::InferenceError(s) => Error::InferenceError(s.clone()),
Error::Oom(s) => Error::Oom(s.clone()),
Error::InvalidConfig(s) => Error::InvalidConfig(s.clone()),
Error::Io(io_err) => Error::Io(std::io::Error::new(io_err.kind(), io_err.to_string())),
Error::InternalError(s) => Error::InternalError(s.clone()),
}
}
}
impl From<String> for Error {
fn from(s: String) -> Self {
Error::LoadError(s)
}
}
impl From<std::io::Error> for Error {
fn from(err: std::io::Error) -> Self {
Error::Io(err)
}
}
pub type Result<T> = std::result::Result<T, Error>;