use std::collections::BTreeMap;
use surfpool_types::{OverrideTemplate, YamlOverrideTemplateCollection};
pub const PYTH_V2_IDL_CONTENT: &str = include_str!("./protocols/pyth/v2/idl.json");
pub const PYTH_V2_OVERRIDES_CONTENT: &str = include_str!("./protocols/pyth/v2/overrides.yaml");
pub const JUPITER_V6_IDL_CONTENT: &str = include_str!("./protocols/jupiter/v6/idl.json");
pub const JUPITER_V6_OVERRIDES_CONTENT: &str =
include_str!("./protocols/jupiter/v6/overrides.yaml");
pub const SWITCHBOARD_ON_DEMAND_IDL_CONTENT: &str =
include_str!("./protocols/switchboard/on-demand/idl.json");
pub const SWITCHBOARD_ON_DEMAND_OVERRIDES_CONTENT: &str =
include_str!("./protocols/switchboard/on-demand/overrides.yaml");
pub const RAYDIUM_CLMM_IDL_CONTENT: &str = include_str!("./protocols/raydium/v3/idl.json");
pub const RAYDIUM_CLMM_OVERRIDES_CONTENT: &str =
include_str!("./protocols/raydium/v3/overrides.yaml");
pub const METEORA_DLMM_IDL_CONTENT: &str = include_str!("./protocols/meteora/dlmm/v1/idl.json");
pub const METEORA_DLMM_OVERRIDES_CONTENT: &str =
include_str!("./protocols/meteora/dlmm/v1/overrides.yaml");
pub const KAMINO_V1_IDL_CONTENT: &str = include_str!("./protocols/kamino/v1/idl.json");
pub const KAMINO_V1_OVERRIDES_CONTENT: &str = include_str!("./protocols/kamino/v1/overrides.yaml");
pub const DRIFT_V2_IDL_CONTENT: &str = include_str!("./protocols/drift/v2/idl.json");
pub const DRIFT_V2_OVERRIDES_CONTENT: &str = include_str!("./protocols/drift/v2/overrides.yaml");
pub const WHIRLPOOL_IDL_CONTENT: &str = include_str!("./protocols/whirlpool/idl.json");
pub const WHIRLPOOL_OVERRIDES_CONTENT: &str = include_str!("./protocols/whirlpool/overrides.yaml");
#[derive(Clone, Debug, Default)]
pub struct TemplateRegistry {
pub templates: BTreeMap<String, OverrideTemplate>,
}
impl TemplateRegistry {
pub fn new() -> Self {
let mut default = Self::default();
default.load_pyth_overrides();
default.load_jupiter_overrides();
default.load_raydium_overrides();
default.load_switchboard_on_demand_overrides();
default.load_meteora_overrides();
default.load_kamino_overrides();
default.load_drift_overrides();
default.load_whirlpool_overrides();
default
}
pub fn load_pyth_overrides(&mut self) {
self.load_protocol_overrides(PYTH_V2_IDL_CONTENT, PYTH_V2_OVERRIDES_CONTENT, "pyth");
}
pub fn load_jupiter_overrides(&mut self) {
self.load_protocol_overrides(
JUPITER_V6_IDL_CONTENT,
JUPITER_V6_OVERRIDES_CONTENT,
"jupiter",
);
}
pub fn load_switchboard_on_demand_overrides(&mut self) {
self.load_protocol_overrides(
SWITCHBOARD_ON_DEMAND_IDL_CONTENT,
SWITCHBOARD_ON_DEMAND_OVERRIDES_CONTENT,
"switchboard-on-demand",
);
}
pub fn load_meteora_overrides(&mut self) {
self.load_protocol_overrides(
METEORA_DLMM_IDL_CONTENT,
METEORA_DLMM_OVERRIDES_CONTENT,
"meteora",
);
}
pub fn load_raydium_overrides(&mut self) {
self.load_protocol_overrides(
RAYDIUM_CLMM_IDL_CONTENT,
RAYDIUM_CLMM_OVERRIDES_CONTENT,
"raydium",
);
}
pub fn load_kamino_overrides(&mut self) {
self.load_protocol_overrides(KAMINO_V1_IDL_CONTENT, KAMINO_V1_OVERRIDES_CONTENT, "kamino");
}
pub fn load_drift_overrides(&mut self) {
self.load_protocol_overrides(DRIFT_V2_IDL_CONTENT, DRIFT_V2_OVERRIDES_CONTENT, "drift");
}
pub fn load_whirlpool_overrides(&mut self) {
self.load_protocol_overrides(
WHIRLPOOL_IDL_CONTENT,
WHIRLPOOL_OVERRIDES_CONTENT,
"whirlpool",
);
}
fn load_protocol_overrides(
&mut self,
idl_content: &str,
overrides_content: &str,
protocol_name: &str,
) {
let idl = match serde_json::from_str(idl_content) {
Ok(idl) => idl,
Err(e) => panic!("unable to load {} idl: {}", protocol_name, e),
};
let Ok(collection) =
serde_yaml::from_str::<YamlOverrideTemplateCollection>(overrides_content)
else {
panic!("unable to load {} overrides", protocol_name);
};
let templates = collection.to_override_templates(idl);
for template in templates {
let template_id = template.id.clone();
self.templates.insert(template_id.clone(), template);
}
}
pub fn get(&self, template_id: &str) -> Option<&OverrideTemplate> {
self.templates.get(template_id)
}
pub fn all(&self) -> Vec<&OverrideTemplate> {
self.templates.values().collect()
}
pub fn by_protocol(&self, protocol: &str) -> Vec<&OverrideTemplate> {
self.templates
.values()
.filter(|t| t.protocol.eq_ignore_ascii_case(protocol))
.collect()
}
pub fn by_tags(&self, tags: &[String]) -> Vec<&OverrideTemplate> {
self.templates
.values()
.filter(|t| t.tags.iter().any(|tag| tags.contains(tag)))
.collect()
}
pub fn count(&self) -> usize {
self.templates.len()
}
pub fn contains(&self, template_id: &str) -> bool {
self.templates.contains_key(template_id)
}
pub fn list_ids(&self) -> Vec<String> {
self.templates.keys().cloned().collect()
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_registry_loads_all_protocols() {
let registry = TemplateRegistry::new();
assert_eq!(
registry.count(),
22,
"Registry should load 22 templates total"
);
assert!(registry.contains("pyth-sol-usd-v2"));
assert!(registry.contains("pyth-btc-usd-v2"));
assert!(registry.contains("pyth-eth-btc-v2"));
assert!(registry.contains("pyth-eth-usd-v2"));
assert!(registry.contains("jupiter-token-ledger-override"));
assert!(registry.contains("raydium-clmm-sol-usdc"));
assert!(registry.contains("raydium-clmm-btc-usdc"));
assert!(registry.contains("raydium-clmm-eth-usdc"));
assert!(registry.contains("switchboard-quote-override"));
assert!(registry.contains("meteora-dlmm-sol-usdc"));
assert!(registry.contains("meteora-dlmm-usdt-sol"));
assert!(registry.contains("kamino-reserve-state"));
assert!(registry.contains("kamino-reserve-config"));
assert!(registry.contains("kamino-obligation-health"));
assert!(registry.contains("drift-perp-market"));
assert!(registry.contains("drift-spot-market"));
assert!(registry.contains("drift-user-state"));
assert!(registry.contains("drift-global-state"));
assert!(registry.contains("whirlpool-sol-usdc"));
assert!(registry.contains("whirlpool-sol-usdt"));
assert!(registry.contains("whirlpool-msol-sol"));
assert!(registry.contains("whirlpool-orca-usdc"));
}
#[test]
fn test_jupiter_template_loads_correctly() {
let registry = TemplateRegistry::new();
let jupiter_template = registry
.get("jupiter-token-ledger-override")
.expect("Jupiter template should exist");
assert_eq!(jupiter_template.protocol, "Jupiter");
assert_eq!(jupiter_template.account_type, "TokenLedger");
assert_eq!(jupiter_template.name, "Override Jupiter Token Ledger");
assert_eq!(jupiter_template.properties.len(), 2);
assert!(
jupiter_template
.properties
.contains(&"tokenAccount".to_string())
);
assert!(jupiter_template.properties.contains(&"amount".to_string()));
assert!(jupiter_template.tags.contains(&"dex".to_string()));
assert!(jupiter_template.tags.contains(&"aggregator".to_string()));
assert!(jupiter_template.tags.contains(&"swap".to_string()));
assert!(jupiter_template.tags.contains(&"defi".to_string()));
}
#[test]
fn test_filter_by_protocol() {
let registry = TemplateRegistry::new();
let pyth_templates = registry.by_protocol("Pyth");
assert_eq!(pyth_templates.len(), 4, "Should have 4 Pyth templates");
let jupiter_templates = registry.by_protocol("Jupiter");
assert_eq!(jupiter_templates.len(), 1, "Should have 1 Jupiter template");
let raydium_templates = registry.by_protocol("Raydium");
assert_eq!(
raydium_templates.len(),
3,
"Should have 3 Raydium templates"
);
let kamino_templates = registry.by_protocol("Kamino");
assert_eq!(kamino_templates.len(), 3, "Should have 3 Kamino templates");
let whirlpool_templates = registry.by_protocol("Whirlpool");
assert_eq!(
whirlpool_templates.len(),
4,
"Should have 4 Whirlpool templates"
);
}
#[test]
fn test_filter_by_tags() {
let registry = TemplateRegistry::new();
let oracle_templates = registry.by_tags(&[vec!["oracle".to_string()]].concat());
assert_eq!(
oracle_templates.len(),
5,
"Should find 5 oracle templates (Pyth + Switchboard)"
);
let dex_templates = registry.by_tags(&[vec!["dex".to_string()]].concat());
assert_eq!(
dex_templates.len(),
1,
"Should find 1 dex template (Jupiter)"
);
let aggregator_templates = registry.by_tags(&[vec!["aggregator".to_string()]].concat());
assert_eq!(
aggregator_templates.len(),
1,
"Should find 1 aggregator template (Jupiter)"
);
}
#[test]
fn test_jupiter_idl_has_token_ledger_account() {
let registry = TemplateRegistry::new();
let jupiter_template = registry.get("jupiter-token-ledger-override").unwrap();
let has_token_ledger = jupiter_template
.idl
.accounts
.iter()
.any(|acc| acc.name == "TokenLedger");
assert!(has_token_ledger, "IDL should contain TokenLedger account");
}
#[test]
fn test_list_all_template_ids() {
let registry = TemplateRegistry::new();
let ids = registry.list_ids();
assert!(ids.contains(&"raydium-clmm-sol-usdc".to_string()));
assert!(ids.contains(&"jupiter-token-ledger-override".to_string()));
assert!(ids.contains(&"pyth-sol-usd-v2".to_string()));
assert!(ids.contains(&"meteora-dlmm-sol-usdc".to_string()));
assert!(ids.contains(&"raydium-clmm-sol-usdc".to_string()));
assert!(ids.contains(&"jupiter-token-ledger-override".to_string()));
assert!(ids.contains(&"pyth-sol-usd-v2".to_string()));
assert!(ids.contains(&"kamino-reserve-state".to_string()));
assert!(ids.contains(&"kamino-reserve-config".to_string()));
assert!(ids.contains(&"kamino-obligation-health".to_string()));
assert!(ids.contains(&"drift-perp-market".to_string()));
assert!(ids.contains(&"whirlpool-sol-usdc".to_string()));
assert!(ids.contains(&"whirlpool-sol-usdt".to_string()));
assert!(ids.contains(&"whirlpool-msol-sol".to_string()));
assert!(ids.contains(&"whirlpool-orca-usdc".to_string()));
}
}
#[test]
fn test_switchboard_template_loads_correctly() {
let registry = TemplateRegistry::new();
let switchboard_template = registry
.get("switchboard-quote-override")
.expect("Switchboard template should exist");
assert_eq!(switchboard_template.protocol, "Switchboard");
assert_eq!(switchboard_template.account_type, "SwitchboardQuote");
assert_eq!(
switchboard_template.name,
"Override Switchboard Oracle Quote"
);
assert_eq!(switchboard_template.properties.len(), 3);
assert!(
switchboard_template
.properties
.contains(&"queue".to_string())
);
assert!(
switchboard_template
.properties
.contains(&"slot".to_string())
);
assert!(
switchboard_template
.properties
.contains(&"version".to_string())
);
assert!(switchboard_template.tags.contains(&"oracle".to_string()));
assert!(
switchboard_template
.tags
.contains(&"price-feed".to_string())
);
}
#[test]
fn test_switchboard_idl_has_quote_account() {
let registry = TemplateRegistry::new();
let switchboard_template = registry.get("switchboard-quote-override").unwrap();
let has_quote_account = switchboard_template
.idl
.accounts
.iter()
.any(|acc| acc.name == "SwitchboardQuote");
assert!(
has_quote_account,
"IDL should contain SwitchboardQuote account"
);
}
#[test]
fn test_filter_by_oracle_tag_includes_switchboard() {
let registry = TemplateRegistry::new();
let oracle_templates = registry.by_tags(&[vec!["oracle".to_string()]].concat());
assert!(
oracle_templates.len() >= 5,
"Should find at least 5 oracle templates (Pyth + Switchboard)"
);
}