rho-coding-agent 0.8.0

A lightweight agent harness inspired by Pi
use std::{fs, path::PathBuf, time::Duration};

use rusqlite::{params, Connection};
use serde::{Deserialize, Serialize};
use serde_json::Value;

#[derive(Clone, Debug, Default, PartialEq, Eq, Deserialize, Serialize)]
pub struct ModelMetadata {
    pub advertised_context_window: Option<u64>,
    pub effective_context_window: Option<u64>,
    pub usable_context_window: Option<u64>,
    pub long_context_threshold: Option<u64>,
    pub max_output_tokens: Option<u64>,
    pub cost_default: Option<ModelCost>,
    pub cost_long_context: Option<ModelCost>,
}

impl ModelMetadata {
    pub fn display_context_window(&self) -> Option<u64> {
        self.usable_context_window
            .or(self.effective_context_window)
            .or(self.advertised_context_window)
    }

    pub fn cost_for_input_tokens(&self, input_tokens: u64) -> Option<ModelCost> {
        if self
            .long_context_threshold
            .is_some_and(|threshold| input_tokens > threshold)
        {
            self.cost_long_context.or(self.cost_default)
        } else {
            self.cost_default
        }
    }
}

#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Deserialize, Serialize)]
pub struct ModelCost {
    pub input_micros_per_m: Option<u64>,
    pub output_micros_per_m: Option<u64>,
    pub cache_read_micros_per_m: Option<u64>,
    pub cache_write_micros_per_m: Option<u64>,
}

pub fn cached_model_metadata(provider: &str, model: &str) -> Option<ModelMetadata> {
    cached_upstream_model_metadata(provider, model)
        .map(|metadata| apply_overrides(provider, model, metadata))
}

pub async fn fetch_model_metadata(provider: &str, model: &str) -> Option<ModelMetadata> {
    if let Some(metadata) = cached_model_metadata(provider, model) {
        return Some(metadata);
    }

    if let Some(metadata) = read_cached_api()
        .as_ref()
        .and_then(|api| upstream_metadata_from_api(api, provider, model))
    {
        write_cached_upstream_model_metadata(provider, model, &metadata);
        return Some(apply_overrides(provider, model, metadata));
    }

    let Some(response) = fetch_models_dev_api().await else {
        return override_metadata(provider, model);
    };
    write_cached_api(&response);
    if let Some(metadata) = upstream_metadata_from_api(&response, provider, model) {
        write_cached_upstream_model_metadata(provider, model, &metadata);
        return Some(apply_overrides(provider, model, metadata));
    }
    override_metadata(provider, model)
}

fn upstream_metadata_from_api(api: &Value, provider: &str, model: &str) -> Option<ModelMetadata> {
    model_metadata_from_api(api, upstream_provider(provider), model)
}

fn apply_overrides(provider: &str, model: &str, metadata: ModelMetadata) -> ModelMetadata {
    let metadata = apply_builtin_overrides(provider, model, metadata);
    apply_local_overrides(provider, model, metadata)
}

fn override_metadata(provider: &str, model: &str) -> Option<ModelMetadata> {
    let metadata = apply_overrides(provider, model, ModelMetadata::default());
    metadata_has_values(&metadata).then_some(metadata)
}

fn metadata_has_values(metadata: &ModelMetadata) -> bool {
    metadata.advertised_context_window.is_some()
        || metadata.effective_context_window.is_some()
        || metadata.usable_context_window.is_some()
        || metadata.long_context_threshold.is_some()
        || metadata.max_output_tokens.is_some()
        || metadata.cost_default.is_some()
        || metadata.cost_long_context.is_some()
}

async fn fetch_models_dev_api() -> Option<Value> {
    reqwest::Client::builder()
        .timeout(Duration::from_secs(5))
        .build()
        .ok()?
        .get("https://models.dev/api.json")
        .header("User-Agent", concat!("rho/", env!("CARGO_PKG_VERSION")))
        .send()
        .await
        .ok()?
        .error_for_status()
        .ok()?
        .json::<Value>()
        .await
        .ok()
}

fn read_cached_api() -> Option<Value> {
    let contents = fs::read_to_string(models_dev_cache_path()).ok()?;
    serde_json::from_str(&contents).ok()
}

fn write_cached_api(value: &Value) {
    let path = models_dev_cache_path();
    if let Some(parent) = path.parent() {
        let _ = fs::create_dir_all(parent);
    }
    if let Ok(contents) = serde_json::to_string(value) {
        let _ = fs::write(path, contents);
    }
}

fn cached_upstream_model_metadata(provider: &str, model: &str) -> Option<ModelMetadata> {
    let upstream_provider = upstream_provider(provider);
    let connection = open_models_dev_cache().ok()?;
    let contents: String = connection
        .query_row(
            "select metadata_json from model_metadata where provider = ?1 and model = ?2",
            params![upstream_provider, model],
            |row| row.get(0),
        )
        .ok()?;
    serde_json::from_str(&contents).ok()
}

fn write_cached_upstream_model_metadata(provider: &str, model: &str, metadata: &ModelMetadata) {
    let upstream_provider = upstream_provider(provider);
    let Ok(connection) = open_models_dev_cache() else {
        return;
    };
    let Ok(contents) = serde_json::to_string(metadata) else {
        return;
    };
    let _ = connection.execute(
        "insert into model_metadata (provider, model, metadata_json, updated_at)
         values (?1, ?2, ?3, strftime('%s', 'now'))
         on conflict(provider, model) do update set
           metadata_json = excluded.metadata_json,
           updated_at = excluded.updated_at",
        params![upstream_provider, model, contents],
    );
}

fn open_models_dev_cache() -> rusqlite::Result<Connection> {
    let path = models_dev_sqlite_path();
    if let Some(parent) = path.parent() {
        let _ = fs::create_dir_all(parent);
    }
    let connection = Connection::open(path)?;
    connection.execute_batch(
        "create table if not exists model_metadata (
            provider text not null,
            model text not null,
            metadata_json text not null,
            updated_at integer not null,
            primary key (provider, model)
        );",
    )?;
    Ok(connection)
}

fn upstream_provider(provider: &str) -> &str {
    match provider {
        "openai" | "openai-codex" => "openai",
        other => other,
    }
}

fn models_dev_sqlite_path() -> PathBuf {
    cache_dir().join("models.dev/models-dev-metadata.sqlite3")
}

fn models_dev_cache_path() -> PathBuf {
    cache_dir().join("models.dev/api.json")
}

fn cache_dir() -> PathBuf {
    if let Some(path) = std::env::var_os("XDG_CACHE_HOME") {
        return PathBuf::from(path).join("rho");
    }
    #[cfg(target_os = "windows")]
    {
        if let Some(path) = std::env::var_os("LOCALAPPDATA") {
            return PathBuf::from(path).join("rho").join("cache");
        }
    }
    #[cfg(target_os = "macos")]
    {
        if let Some(path) = std::env::var_os("HOME") {
            return PathBuf::from(path)
                .join("Library")
                .join("Caches")
                .join("rho");
        }
    }
    if let Some(path) = std::env::var_os("HOME") {
        return PathBuf::from(path).join(".cache").join("rho");
    }
    std::env::temp_dir().join("rho-cache")
}

fn model_metadata_from_api(api: &Value, provider: &str, model: &str) -> Option<ModelMetadata> {
    let model = api.get(provider)?.get("models")?.get(model).or_else(|| {
        api.get(provider)?
            .get("models")?
            .get(model.strip_prefix("openai/")?)
    })?;
    let limit = model.get("limit");
    let cost = model.get("cost");
    Some(ModelMetadata {
        advertised_context_window: limit
            .and_then(|limit| limit.get("context"))
            .and_then(|value| value.as_u64()),
        effective_context_window: limit
            .and_then(|limit| limit.get("input").or_else(|| limit.get("context")))
            .and_then(|value| value.as_u64()),
        usable_context_window: None,
        long_context_threshold: None,
        max_output_tokens: limit
            .and_then(|limit| limit.get("output"))
            .and_then(|value| value.as_u64()),
        cost_default: Some(ModelCost {
            input_micros_per_m: cost
                .and_then(|cost| cost.get("input"))
                .and_then(cost_micros_per_million),
            output_micros_per_m: cost
                .and_then(|cost| cost.get("output"))
                .and_then(cost_micros_per_million),
            cache_read_micros_per_m: cost
                .and_then(|cost| cost.get("cache_read"))
                .and_then(cost_micros_per_million),
            cache_write_micros_per_m: cost
                .and_then(|cost| cost.get("cache_write"))
                .and_then(cost_micros_per_million),
        }),
        cost_long_context: None,
    })
}

fn apply_builtin_overrides(
    provider: &str,
    model: &str,
    mut metadata: ModelMetadata,
) -> ModelMetadata {
    match (provider, model) {
        ("openai", "gpt-5.5") => {
            metadata.advertised_context_window = Some(1_050_000);
            metadata.effective_context_window = Some(272_000);
            metadata.usable_context_window = Some(272_000);
            metadata.long_context_threshold = Some(272_000);
            metadata.max_output_tokens = Some(128_000);
            metadata.cost_default = Some(ModelCost {
                input_micros_per_m: Some(5_000_000),
                output_micros_per_m: Some(30_000_000),
                cache_read_micros_per_m: Some(500_000),
                cache_write_micros_per_m: None,
            });
            metadata.cost_long_context = Some(ModelCost {
                input_micros_per_m: Some(10_000_000),
                output_micros_per_m: Some(45_000_000),
                cache_read_micros_per_m: Some(1_000_000),
                cache_write_micros_per_m: None,
            });
        }
        ("openai-codex", "gpt-5.5") => {
            metadata.advertised_context_window = Some(1_050_000);
            metadata.effective_context_window = Some(400_000);
            metadata.usable_context_window = Some(272_000);
            metadata.long_context_threshold = Some(272_000);
            metadata.max_output_tokens = Some(128_000);
            metadata.cost_default = Some(ModelCost {
                input_micros_per_m: Some(5_000_000),
                output_micros_per_m: Some(30_000_000),
                cache_read_micros_per_m: Some(500_000),
                cache_write_micros_per_m: None,
            });
            metadata.cost_long_context = Some(ModelCost {
                input_micros_per_m: Some(10_000_000),
                output_micros_per_m: Some(45_000_000),
                cache_read_micros_per_m: Some(1_000_000),
                cache_write_micros_per_m: None,
            });
        }
        _ => {}
    }
    metadata
}

fn apply_local_overrides(provider: &str, model: &str, metadata: ModelMetadata) -> ModelMetadata {
    let Some(path) = local_overrides_path() else {
        return metadata;
    };
    let Ok(contents) = fs::read_to_string(path) else {
        return metadata;
    };
    let Ok(value) = contents.parse::<toml::Value>() else {
        return metadata;
    };
    let key = format!("{provider}/{model}");
    let Some(table) = value
        .get("models")
        .and_then(|models| models.get(&key))
        .and_then(|value| value.as_table())
    else {
        return metadata;
    };

    merge_toml_override(metadata, table)
}

fn local_overrides_path() -> Option<PathBuf> {
    if let Some(path) = std::env::var_os("RHO_MODELS_PATH") {
        return Some(path.into());
    }
    Some(PathBuf::from(std::env::var_os("HOME")?).join(".rho/models.toml"))
}

fn merge_toml_override(
    mut metadata: ModelMetadata,
    table: &toml::map::Map<String, toml::Value>,
) -> ModelMetadata {
    metadata.advertised_context_window =
        toml_u64(table, "advertised_context_window").or(metadata.advertised_context_window);
    metadata.effective_context_window =
        toml_u64(table, "effective_context_window").or(metadata.effective_context_window);
    metadata.usable_context_window =
        toml_u64(table, "usable_context_window").or(metadata.usable_context_window);
    metadata.long_context_threshold =
        toml_u64(table, "long_context_threshold").or(metadata.long_context_threshold);
    metadata.max_output_tokens =
        toml_u64(table, "max_output_tokens").or(metadata.max_output_tokens);
    metadata.cost_default = toml_cost(table, "cost_default").or(metadata.cost_default);
    metadata.cost_long_context =
        toml_cost(table, "cost_long_context").or(metadata.cost_long_context);
    metadata
}

fn toml_u64(table: &toml::map::Map<String, toml::Value>, key: &str) -> Option<u64> {
    table
        .get(key)
        .and_then(|value| value.as_integer())
        .and_then(|value| u64::try_from(value).ok())
}

fn toml_cost(table: &toml::map::Map<String, toml::Value>, key: &str) -> Option<ModelCost> {
    let table = table.get(key)?.as_table()?;
    Some(ModelCost {
        input_micros_per_m: toml_cost_value(table, "input"),
        output_micros_per_m: toml_cost_value(table, "output"),
        cache_read_micros_per_m: toml_cost_value(table, "cache_read"),
        cache_write_micros_per_m: toml_cost_value(table, "cache_write"),
    })
}

fn toml_cost_value(table: &toml::map::Map<String, toml::Value>, key: &str) -> Option<u64> {
    let dollars = table.get(key).and_then(|value| {
        value
            .as_float()
            .or_else(|| value.as_integer().map(|v| v as f64))
    })?;
    dollars
        .is_finite()
        .then(|| (dollars.max(0.0) * 1_000_000.0).round() as u64)
}

fn cost_micros_per_million(value: &Value) -> Option<u64> {
    let dollars = value.as_f64().or_else(|| {
        value
            .as_str()?
            .trim_start_matches('$')
            .replace(',', "")
            .parse()
            .ok()
    })?;
    dollars
        .is_finite()
        .then(|| (dollars.max(0.0) * 1_000_000.0).round() as u64)
}