ferrox-solver 0.3.12

Iron-forged OR-Tools and HiGHS solvers as Converge Suggestors
Documentation
use async_trait::async_trait;
use converge_pack::{AgentEffect, Context, ContextKey, ProposedFact, Suggestor};
use ferrox_ortools_sys::OrtoolsStatus;
use ferrox_ortools_sys::safe::CpModel;
use std::collections::HashMap;
use tracing::warn;

use super::problem::{ConstraintKind, CpSatPlan, CpSatRequest};

/// Fact-id prefix that marks a seed fact as a CP-SAT request; the remainder of
/// the id is treated as the request id.
const REQUEST_PREFIX: &str = "cpsat-request:";
/// Fact-id prefix used for plans proposed by this module; it is followed by a
/// request id.
const PLAN_PREFIX: &str = "cpsat-plan:";

/// Stateless suggestor that reads `cpsat-request:` facts from the seeds,
/// solves each with CP-SAT, and proposes the resulting plan as a strategy fact.
pub struct CpSatSuggestor;

#[async_trait]
impl Suggestor for CpSatSuggestor {
    /// Stable identifier of this suggestor; also recorded on every proposal it emits.
    fn name(&self) -> &'static str {
        "CpSatSuggestor"
    }

    /// This suggestor only reads from the seed facts.
    fn dependencies(&self) -> &[ContextKey] {
        &[ContextKey::Seeds]
    }

    /// Human-readable note about the expected cost of a solve.
    fn complexity_hint(&self) -> Option<&'static str> {
        Some("NP-hard in general; CP-SAT DPLL+propagation+LNS; practical for n≤500 vars")
    }

    /// Returns `true` when at least one seed is a CP-SAT request that has no
    /// matching plan among the strategies yet.
    fn accepts(&self, ctx: &dyn Context) -> bool {
        for seed in ctx.get(ContextKey::Seeds).iter() {
            if seed.id.starts_with(REQUEST_PREFIX) && !plan_exists(ctx, request_id(&seed.id)) {
                return true;
            }
        }
        false
    }

    /// Solves every pending CP-SAT request found in the seeds and proposes one
    /// plan fact per request that parsed successfully.
    ///
    /// Requests whose payload is not valid JSON for `CpSatRequest` are logged
    /// and skipped. Confidence is `1.0` for an optimal plan, `0.7` for a
    /// feasible one and `0.0` for any other status.
    async fn execute(&self, ctx: &dyn Context) -> AgentEffect {
        let mut proposals = Vec::new();

        for fact in ctx.get(ContextKey::Seeds).iter() {
            if !fact.id.starts_with(REQUEST_PREFIX) {
                continue;
            }
            // Skip requests that already have a plan so repeated runs stay idempotent.
            if plan_exists(ctx, request_id(&fact.id)) {
                continue;
            }

            let req = match serde_json::from_str::<CpSatRequest>(&fact.content) {
                Ok(parsed) => parsed,
                Err(e) => {
                    warn!(id = %fact.id, error = %e, "malformed cpsat-request");
                    continue;
                }
            };

            let plan = solve_cp(&req);
            let confidence = match plan.status.as_str() {
                "optimal" => 1.0,
                "feasible" => 0.7,
                _ => 0.0,
            };

            // NOTE(review): the plan id is built from the payload's `id`, while the
            // pending check above uses the id taken from the fact id — confirm
            // that producers always keep the two identical.
            let plan_fact_id = format!("{PLAN_PREFIX}{}", plan.request_id);
            let payload = serde_json::to_string(&plan).unwrap_or_default();
            let proposal =
                ProposedFact::new(ContextKey::Strategies, plan_fact_id, payload, self.name())
                    .with_confidence(confidence);
            proposals.push(proposal);
        }

        if proposals.is_empty() {
            return AgentEffect::empty();
        }
        AgentEffect::with_proposals(proposals)
    }
}

/// Extracts the request id from a request fact id.
///
/// Removes a single leading `REQUEST_PREFIX` from `fact_id`. Ids that do not
/// start with the prefix are returned unchanged. The prefix is stripped at most
/// once, so an id whose payload itself begins with the prefix text keeps that
/// payload intact.
fn request_id(fact_id: &str) -> &str {
    // `trim_start_matches` would strip the prefix repeatedly; `strip_prefix`
    // removes exactly one occurrence.
    fact_id.strip_prefix(REQUEST_PREFIX).unwrap_or(fact_id)
}

/// Reports whether the strategies already contain a plan fact for `request_id`,
/// i.e. a fact whose id is `PLAN_PREFIX` followed by `request_id`.
fn plan_exists(ctx: &dyn Context, request_id: &str) -> bool {
    let wanted: String = [PLAN_PREFIX, request_id].concat();
    for strategy in ctx.get(ContextKey::Strategies).iter() {
        if strategy.id == wanted {
            return true;
        }
    }
    false
}

/// Builds a CP-SAT model from `req`, solves it, and returns the outcome as a
/// [`CpSatPlan`].
///
/// The model is assembled in a fixed order: variables, fixed interval
/// variables, optional interval variables, constraints, and finally the
/// objective when `req.objective_terms` is present (minimised when
/// `req.minimize` is set, maximised otherwise). The solve runs with
/// `req.time_limit_seconds`, defaulting to 60 seconds.
///
/// Names that cannot be resolved are never fatal: an interval declaration that
/// references an unknown variable is skipped with a warning, and an unknown
/// variable or interval listed in an `AllDifferent` / `NoOverlap` constraint is
/// dropped from that constraint with a warning. Linear constraint and objective
/// terms are resolved through `terms_to_vecs`.
///
/// # Returns
///
/// A plan whose `status` is one of `"optimal"`, `"feasible"`, `"infeasible"`,
/// `"unbounded"` or `"error"`. `assignments` holds one `(name, value)` pair per
/// request variable when the solver status reports success and is empty
/// otherwise; `objective_value` is `Some` only when the solve succeeded and an
/// objective was supplied.
// One linear model-building pipeline; kept in a single function so the order of
// solver calls (which determines variable indices) stays easy to audit.
#[allow(clippy::too_many_lines)]
pub fn solve_cp(req: &CpSatRequest) -> CpSatPlan {
    let mut model = CpModel::new();
    let mut name_to_idx: HashMap<String, i32> = HashMap::new();
    let mut bool_name_to_idx: HashMap<String, i32> = HashMap::new();
    let mut interval_name_to_idx: HashMap<String, i32> = HashMap::new();

    // Every variable goes into `name_to_idx`; booleans are additionally tracked
    // on their own so they can be looked up as optional-interval literals.
    for var in &req.variables {
        let idx = if var.is_bool {
            let i = model.new_bool_var(&var.name);
            bool_name_to_idx.insert(var.name.clone(), i);
            i
        } else {
            model.new_int_var(var.lb, var.ub, &var.name)
        };
        name_to_idx.insert(var.name.clone(), idx);
    }

    for ivd in &req.interval_vars {
        if let (Some(&s), Some(&e)) =
            (name_to_idx.get(&ivd.start_var), name_to_idx.get(&ivd.end_var))
        {
            let idx = model.new_fixed_interval_var(s, ivd.duration, e, &ivd.name);
            interval_name_to_idx.insert(ivd.name.clone(), idx);
        } else {
            warn!(name = %ivd.name, "interval_var references unknown start/end variable");
        }
    }

    for ovd in &req.optional_interval_vars {
        match (
            name_to_idx.get(&ovd.start_var),
            name_to_idx.get(&ovd.end_var),
            // The presence literal must have been declared as a boolean variable.
            bool_name_to_idx.get(&ovd.lit_var),
        ) {
            (Some(&s), Some(&e), Some(&lit)) => {
                let idx =
                    model.new_optional_interval_var(s, ovd.duration, e, lit, &ovd.name);
                interval_name_to_idx.insert(ovd.name.clone(), idx);
            }
            _ => {
                warn!(name = %ovd.name, "optional_interval_var references unknown variable(s)");
            }
        }
    }

    for constraint in &req.constraints {
        match constraint {
            ConstraintKind::LinearLe { terms, rhs } => {
                let (vars, coeffs) = terms_to_vecs(terms, &name_to_idx);
                model.add_linear_le(&vars, &coeffs, *rhs);
            }
            ConstraintKind::LinearGe { terms, rhs } => {
                let (vars, coeffs) = terms_to_vecs(terms, &name_to_idx);
                model.add_linear_ge(&vars, &coeffs, *rhs);
            }
            ConstraintKind::LinearEq { terms, rhs } => {
                let (vars, coeffs) = terms_to_vecs(terms, &name_to_idx);
                model.add_linear_eq(&vars, &coeffs, *rhs);
            }
            ConstraintKind::AllDifferent { vars } => {
                let idxs: Vec<i32> = vars
                    .iter()
                    .filter_map(|v| {
                        let idx = name_to_idx.get(v).copied();
                        if idx.is_none() {
                            // Surface the mismatch instead of silently weakening the constraint.
                            warn!(
                                variable = %v,
                                "all_different constraint references unknown variable; entry dropped"
                            );
                        }
                        idx
                    })
                    .collect();
                model.add_all_different(&idxs);
            }
            ConstraintKind::NoOverlap { intervals } => {
                let idxs: Vec<i32> = intervals
                    .iter()
                    .filter_map(|n| {
                        let idx = interval_name_to_idx.get(n).copied();
                        if idx.is_none() {
                            // Surface the mismatch instead of silently weakening the constraint.
                            warn!(
                                interval = %n,
                                "no_overlap constraint references unknown interval; entry dropped"
                            );
                        }
                        idx
                    })
                    .collect();
                model.add_no_overlap(&idxs);
            }
        }
    }

    if let Some(obj_terms) = &req.objective_terms {
        let (vars, coeffs) = terms_to_vecs(obj_terms, &name_to_idx);
        if req.minimize {
            model.minimize(&vars, &coeffs);
        } else {
            model.maximize(&vars, &coeffs);
        }
    }

    let time_limit = req.time_limit_seconds.unwrap_or(60.0);
    let solution = model.solve(time_limit);

    let status = match solution.status() {
        OrtoolsStatus::Optimal => "optimal",
        OrtoolsStatus::Feasible => "feasible",
        OrtoolsStatus::Infeasible => "infeasible",
        OrtoolsStatus::Unbounded => "unbounded",
        _ => "error",
    };

    // Values and objective are only read back when the solver reports success.
    let solved = solution.status().is_success();

    let assignments = if solved {
        req.variables
            .iter()
            .filter_map(|v| {
                name_to_idx
                    .get(&v.name)
                    .map(|&idx| (v.name.clone(), solution.value(idx)))
            })
            .collect()
    } else {
        vec![]
    };

    let objective_value = if solved && req.objective_terms.is_some() {
        Some(solution.objective_value())
    } else {
        None
    };

    CpSatPlan {
        request_id: req.id.clone(),
        status: status.to_string(),
        assignments,
        objective_value,
        wall_time_seconds: solution.wall_time(),
        solver: "cp-sat-v9.15".to_string(),
    }
}

fn terms_to_vecs(
    terms: &[crate::cp::problem::CpTerm],
    name_to_idx: &HashMap<String, i32>,
) -> (Vec<i32>, Vec<i64>) {
    terms
        .iter()
        .filter_map(|t| name_to_idx.get(&t.var).map(|&idx| (idx, t.coeff)))
        .unzip()
}