use std::collections::HashMap;
use std::path::Path;
use serde::{Deserialize, Serialize};
use crate::review_channel::{ReviewChannel, ReviewChannelError};
use crate::session_channel::{SessionChannel, SessionChannelError};
/// Feature flags describing what a channel type can do.
///
/// Returned by [`ChannelFactory::capabilities`]. Note that the `ChannelRegistry`
/// build methods do not consult these flags; they call the factory directly.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChannelCapabilitySet {
    /// Whether the channel type is advertised as able to act as a [`ReviewChannel`].
    pub supports_review: bool,
    /// Whether the channel type is advertised as able to host a [`SessionChannel`].
    pub supports_session: bool,
    /// Whether the channel type is advertised as a notification target
    /// (presumably consulted for `notify` routes — confirm with callers).
    pub supports_notify: bool,
    /// Whether the channel can render rich media (semantics defined by callers).
    pub supports_rich_media: bool,
    /// Whether the channel supports threaded conversations (semantics defined by callers).
    pub supports_threads: bool,
}
impl Default for ChannelCapabilitySet {
fn default() -> Self {
Self {
supports_review: true,
supports_session: true,
supports_notify: true,
supports_rich_media: false,
supports_threads: false,
}
}
}
/// Builds channel instances of one named type from route configuration.
///
/// Implementations are registered in a `ChannelRegistry`, keyed by
/// [`ChannelFactory::channel_type`]. The `Send + Sync` bound allows a factory
/// to be shared across threads.
pub trait ChannelFactory: Send + Sync {
    /// The channel type name this factory handles (matched against the `type`
    /// key of a route config, e.g. `"terminal"`).
    fn channel_type(&self) -> &str;
    /// Builds a review channel from the route's extra config keys.
    ///
    /// `config` holds every key of the route mapping other than `type`.
    fn build_review(
        &self,
        config: &serde_json::Value,
    ) -> Result<Box<dyn ReviewChannel>, ReviewChannelError>;
    /// Builds an interactive session channel from the route's extra config keys.
    ///
    /// Factories that do not support sessions return an error.
    fn build_session(
        &self,
        config: &serde_json::Value,
    ) -> Result<Box<dyn SessionChannel>, SessionChannelError>;
    /// Describes which features this channel type advertises.
    fn capabilities(&self) -> ChannelCapabilitySet;
}
/// Registry of channel factories, keyed by channel type name.
pub struct ChannelRegistry {
    // Key is the factory's `channel_type()` at registration time.
    factories: HashMap<String, Box<dyn ChannelFactory>>,
}
impl ChannelRegistry {
    /// Creates an empty registry with no factories registered.
    ///
    /// Use [`default_registry`] for a registry pre-populated with the built-in channels.
    pub fn new() -> Self {
        Self {
            factories: HashMap::new(),
        }
    }

    /// Registers `factory` under the name returned by [`ChannelFactory::channel_type`].
    ///
    /// A factory already registered under the same name is replaced.
    pub fn register(&mut self, factory: Box<dyn ChannelFactory>) {
        let name = factory.channel_type().to_string();
        self.factories.insert(name, factory);
    }

    /// Looks up the factory registered for `channel_type`, if any.
    pub fn get(&self, channel_type: &str) -> Option<&dyn ChannelFactory> {
        self.factories.get(channel_type).map(|f| f.as_ref())
    }

    /// Returns the names of all registered channel types, sorted alphabetically.
    ///
    /// `HashMap` iteration order is unspecified; sorting keeps this list, and the
    /// error messages that embed it, deterministic.
    pub fn channel_types(&self) -> Vec<&str> {
        let mut types: Vec<&str> = self.factories.keys().map(String::as_str).collect();
        types.sort_unstable();
        types
    }

    /// Returns `true` if a factory is registered for `channel_type`.
    pub fn has_channel(&self, channel_type: &str) -> bool {
        self.factories.contains_key(channel_type)
    }

    /// Number of registered factories.
    pub fn len(&self) -> usize {
        self.factories.len()
    }

    /// Returns `true` if no factories are registered.
    pub fn is_empty(&self) -> bool {
        self.factories.is_empty()
    }

    /// Builds a single review channel for `route`.
    ///
    /// # Errors
    ///
    /// Returns [`ReviewChannelError::Other`] if `route.channel_type` is not
    /// registered, or whatever error the factory's `build_review` returns.
    pub fn build_review_from_config(
        &self,
        route: &ChannelRouteConfig,
    ) -> Result<Box<dyn ReviewChannel>, ReviewChannelError> {
        let factory = self.get(&route.channel_type).ok_or_else(|| {
            ReviewChannelError::Other(self.unknown_type_message(&route.channel_type))
        })?;
        factory.build_review(&route.config)
    }

    /// Builds the review channel described by a (possibly multi-channel) route.
    ///
    /// A route with exactly one entry yields that channel directly; two or more
    /// entries are wrapped in a `MultiReviewChannel` using `strategy`.
    ///
    /// # Errors
    ///
    /// Returns [`ReviewChannelError::Other`] if the route lists no channels, and
    /// otherwise the first error produced while building any listed channel.
    pub fn build_review_from_route(
        &self,
        route: &ReviewRouteConfig,
        strategy: &crate::multi_channel::MultiChannelStrategy,
    ) -> Result<Box<dyn ReviewChannel>, ReviewChannelError> {
        let configs = route.configs();
        match configs.as_slice() {
            // An empty list (e.g. `review: []`) has nothing to build; report it
            // instead of constructing a multi-channel wrapper with no channels.
            [] => Err(ReviewChannelError::Other(
                "review route has no channels configured".to_string(),
            )),
            // A single entry needs no multi-channel wrapper.
            [single] => self.build_review_from_config(single),
            many => {
                // Short-circuits on the first channel that fails to build.
                let channels = many
                    .iter()
                    .map(|config| self.build_review_from_config(config))
                    .collect::<Result<Vec<_>, _>>()?;
                let multi: Box<dyn ReviewChannel> =
                    Box::new(crate::multi_channel::MultiReviewChannel::new(
                        channels,
                        strategy.clone(),
                    ));
                Ok(multi)
            }
        }
    }

    /// Builds an interactive session channel for `route`.
    ///
    /// # Errors
    ///
    /// Returns [`SessionChannelError::Other`] if `route.channel_type` is not
    /// registered, or whatever error the factory's `build_session` returns.
    pub fn build_session_from_config(
        &self,
        route: &ChannelRouteConfig,
    ) -> Result<Box<dyn SessionChannel>, SessionChannelError> {
        let factory = self.get(&route.channel_type).ok_or_else(|| {
            SessionChannelError::Other(self.unknown_type_message(&route.channel_type))
        })?;
        factory.build_session(&route.config)
    }

    /// Error text shared by the review and session builders for an unregistered type.
    fn unknown_type_message(&self, channel_type: &str) -> String {
        format!(
            "unknown channel type: '{}'. Registered: {:?}",
            channel_type,
            self.channel_types()
        )
    }
}
impl Default for ChannelRegistry {
fn default() -> Self {
Self::new()
}
}
/// One channel route entry from the config file.
///
/// The `type` key selects a registered factory. Every other key in the mapping
/// is captured into `config` (via `#[serde(flatten)]`) and passed to that factory.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChannelRouteConfig {
    /// Channel type name, read from / written to the `type` key.
    #[serde(rename = "type")]
    pub channel_type: String,
    /// All remaining keys of the route mapping (e.g. `endpoint` for webhooks).
    #[serde(flatten)]
    pub config: serde_json::Value,
}
impl Default for ChannelRouteConfig {
fn default() -> Self {
Self {
channel_type: "terminal".to_string(),
config: serde_json::Value::Object(serde_json::Map::new()),
}
}
}
/// One notification target from the `notify` list of the config file.
///
/// Like [`ChannelRouteConfig`], but with an extra `level` key that defaults to
/// `"info"` when omitted.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NotifyRouteConfig {
    /// Channel type name, read from / written to the `type` key.
    #[serde(rename = "type")]
    pub channel_type: String,
    /// Notification level as a raw string (tests use `"warning"`); defaults to `"info"`.
    #[serde(default = "default_notify_level")]
    pub level: String,
    /// All remaining keys of the mapping other than `type` and `level`.
    #[serde(flatten)]
    pub config: serde_json::Value,
}
/// Serde default for [`NotifyRouteConfig::level`] when the key is absent.
fn default_notify_level() -> String {
    String::from("info")
}
/// The `review` route: either a single channel mapping or a list of them.
///
/// `#[serde(untagged)]` lets the config file use either shape. A mapping
/// deserializes as `Single` and a sequence as `Multiple`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum ReviewRouteConfig {
    /// `review: { type: ... }`
    Single(ChannelRouteConfig),
    /// `review: [ { type: ... }, ... ]`
    Multiple(Vec<ChannelRouteConfig>),
}
impl Default for ReviewRouteConfig {
fn default() -> Self {
Self::Single(ChannelRouteConfig::default())
}
}
impl ReviewRouteConfig {
    /// Borrows every route entry as a flat list, in declaration order.
    ///
    /// A `Single` route yields a one-element list.
    pub fn configs(&self) -> Vec<&ChannelRouteConfig> {
        let entries: &[ChannelRouteConfig] = match self {
            Self::Single(only) => std::slice::from_ref(only),
            Self::Multiple(list) => list.as_slice(),
        };
        entries.iter().collect()
    }

    /// `true` only for a `Multiple` route holding more than one entry.
    pub fn is_multi(&self) -> bool {
        if let Self::Multiple(list) = self {
            list.len() > 1
        } else {
            false
        }
    }
}
/// The `escalation` route: either a single channel mapping or a list of them.
///
/// `#[serde(untagged)]` lets the config file use either shape.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum EscalationRouteConfig {
    /// `escalation: { type: ... }`
    Single(ChannelRouteConfig),
    /// `escalation: [ { type: ... }, ... ]`
    Multiple(Vec<ChannelRouteConfig>),
}
impl EscalationRouteConfig {
    /// Borrows every escalation entry as a flat list, in declaration order.
    ///
    /// A `Single` route yields a one-element list.
    pub fn configs(&self) -> Vec<&ChannelRouteConfig> {
        let entries: &[ChannelRouteConfig] = match self {
            Self::Single(only) => std::slice::from_ref(only),
            Self::Multiple(list) => list.as_slice(),
        };
        entries.iter().collect()
    }
}
/// The `channels` section of the config file: where each kind of traffic is routed.
///
/// Every field is `#[serde(default)]`, so any subset may be omitted.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct ChannelRoutingConfig {
    /// Review route(s); defaults to a single `terminal` route.
    #[serde(default)]
    pub review: ReviewRouteConfig,
    /// Notification targets; empty when omitted.
    #[serde(default)]
    pub notify: Vec<NotifyRouteConfig>,
    /// Interactive session route; defaults to `terminal`.
    #[serde(default)]
    pub session: ChannelRouteConfig,
    /// Optional escalation route(s).
    #[serde(default)]
    pub escalation: Option<EscalationRouteConfig>,
    /// Optional default agent name (tests use `claude-code`); not interpreted in this module.
    #[serde(default)]
    pub default_agent: Option<String>,
    /// Optional default workflow name; not interpreted in this module.
    #[serde(default)]
    pub default_workflow: Option<String>,
    /// Optional multi-channel strategy name as a raw string (tests use `first_response`).
    /// NOTE(review): presumably mapped to `MultiChannelStrategy` by the caller — confirm.
    #[serde(default)]
    pub strategy: Option<String>,
}
/// Top-level project configuration, as read by `load_config` from `.ta/config.yaml`.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct TaConfig {
    /// Channel routing; defaults to terminal routes when the section is omitted.
    #[serde(default)]
    pub channels: ChannelRoutingConfig,
}
/// Loads `<project_root>/.ta/config.yaml`, falling back to [`TaConfig::default`].
///
/// This function never fails. A missing or blank file silently yields the
/// default config. A file that cannot be read or does not parse as YAML also
/// yields the default, but a warning naming the file and the error is printed
/// to stderr, so a typo in the config does not go unnoticed.
pub fn load_config(project_root: &Path) -> TaConfig {
    let config_path = project_root.join(".ta").join("config.yaml");
    // Read directly and treat NotFound as "no config" rather than checking
    // `exists()` first, which would race with concurrent file changes.
    let content = match std::fs::read_to_string(&config_path) {
        Ok(content) => content,
        Err(e) if e.kind() == std::io::ErrorKind::NotFound => return TaConfig::default(),
        Err(e) => {
            eprintln!(
                "warning: could not read {}: {}; using default channel config",
                config_path.display(),
                e
            );
            return TaConfig::default();
        }
    };
    // An empty file is a valid way to say "use defaults"; don't warn about it.
    if content.trim().is_empty() {
        return TaConfig::default();
    }
    match serde_yaml::from_str(&content) {
        Ok(config) => config,
        Err(e) => {
            eprintln!(
                "warning: could not parse {}: {}; using default channel config",
                config_path.display(),
                e
            );
            TaConfig::default()
        }
    }
}
/// Factory for the built-in `"terminal"` channel type.
///
/// Reviews use `TerminalChannel::stdio()` and sessions use
/// `TerminalSessionChannel::new()`. The route config is ignored.
pub struct TerminalChannelFactory;

impl ChannelFactory for TerminalChannelFactory {
    fn channel_type(&self) -> &str {
        "terminal"
    }

    fn build_review(
        &self,
        _config: &serde_json::Value,
    ) -> Result<Box<dyn ReviewChannel>, ReviewChannelError> {
        let channel: Box<dyn ReviewChannel> =
            Box::new(crate::terminal_channel::TerminalChannel::stdio());
        Ok(channel)
    }

    fn build_session(
        &self,
        _config: &serde_json::Value,
    ) -> Result<Box<dyn SessionChannel>, SessionChannelError> {
        let session: Box<dyn SessionChannel> =
            Box::new(crate::terminal_channel::TerminalSessionChannel::new());
        Ok(session)
    }

    fn capabilities(&self) -> ChannelCapabilitySet {
        // Review, session and notify are supported; rich media and threads are not.
        ChannelCapabilitySet {
            supports_rich_media: false,
            supports_threads: false,
            supports_review: true,
            supports_session: true,
            supports_notify: true,
        }
    }
}
/// Factory for the built-in `"auto-approve"` channel type.
///
/// Reviews are built as `AutoApproveChannel::new()`. The route config is ignored.
pub struct AutoApproveChannelFactory;

impl ChannelFactory for AutoApproveChannelFactory {
    fn channel_type(&self) -> &str {
        "auto-approve"
    }

    fn build_review(
        &self,
        _config: &serde_json::Value,
    ) -> Result<Box<dyn ReviewChannel>, ReviewChannelError> {
        let channel: Box<dyn ReviewChannel> =
            Box::new(crate::terminal_channel::AutoApproveChannel::new());
        Ok(channel)
    }

    /// Falls back to a terminal session channel.
    ///
    /// NOTE(review): `capabilities()` below reports `supports_session: false`, yet
    /// this succeeds — confirm which of the two is intended.
    fn build_session(
        &self,
        _config: &serde_json::Value,
    ) -> Result<Box<dyn SessionChannel>, SessionChannelError> {
        let session: Box<dyn SessionChannel> =
            Box::new(crate::terminal_channel::TerminalSessionChannel::new());
        Ok(session)
    }

    fn capabilities(&self) -> ChannelCapabilitySet {
        // Advertised for review only.
        ChannelCapabilitySet {
            supports_rich_media: false,
            supports_threads: false,
            supports_notify: false,
            supports_session: false,
            supports_review: true,
        }
    }
}
/// Factory for the built-in `"webhook"` channel type.
///
/// Requires a non-empty string `endpoint` key in the route config. Supports
/// review channels only; interactive sessions are rejected.
pub struct WebhookChannelFactory;

impl ChannelFactory for WebhookChannelFactory {
    fn channel_type(&self) -> &str {
        "webhook"
    }

    /// Builds a `WebhookChannel` for the configured `endpoint`.
    ///
    /// # Errors
    ///
    /// Returns [`ReviewChannelError::Other`] if `endpoint` is missing, is not a
    /// string, or is empty / whitespace-only.
    fn build_review(
        &self,
        config: &serde_json::Value,
    ) -> Result<Box<dyn ReviewChannel>, ReviewChannelError> {
        let endpoint = config
            .get("endpoint")
            .and_then(|v| v.as_str())
            // An empty endpoint can never be delivered to; treat it like a missing one.
            .filter(|endpoint| !endpoint.trim().is_empty())
            .ok_or_else(|| {
                ReviewChannelError::Other("webhook requires 'endpoint' in config".into())
            })?;
        Ok(Box::new(crate::webhook_channel::WebhookChannel::new(
            endpoint,
        )))
    }

    /// Always fails: webhooks cannot host interactive sessions.
    fn build_session(
        &self,
        _config: &serde_json::Value,
    ) -> Result<Box<dyn SessionChannel>, SessionChannelError> {
        Err(SessionChannelError::Other(
            "webhook does not support interactive sessions".into(),
        ))
    }

    fn capabilities(&self) -> ChannelCapabilitySet {
        ChannelCapabilitySet {
            supports_review: true,
            supports_session: false,
            supports_notify: true,
            supports_rich_media: false,
            supports_threads: false,
        }
    }
}
pub fn default_registry() -> ChannelRegistry {
let mut registry = ChannelRegistry::new();
registry.register(Box::new(TerminalChannelFactory));
registry.register(Box::new(AutoApproveChannelFactory));
registry.register(Box::new(WebhookChannelFactory));
registry
}
#[cfg(test)]
mod tests {
    //! Unit tests for the channel registry, built-in factories and config parsing.
    //!
    //! YAML fixtures use real nesting indentation: without it, keys like `type:`
    //! would parse as top-level keys instead of belonging to their route.
    use super::*;

    #[test]
    fn default_registry_has_builtins() {
        let registry = default_registry();
        assert!(registry.has_channel("terminal"));
        assert!(registry.has_channel("auto-approve"));
        assert!(registry.has_channel("webhook"));
        assert!(!registry.has_channel("slack"));
        assert_eq!(registry.len(), 3);
    }

    #[test]
    fn channel_types_lists_builtins() {
        let registry = default_registry();
        let mut types = registry.channel_types();
        // Sort locally so the assertion does not depend on registry ordering.
        types.sort_unstable();
        assert_eq!(types, vec!["auto-approve", "terminal", "webhook"]);
    }

    #[test]
    fn build_review_from_config() {
        let registry = default_registry();
        let route = ChannelRouteConfig {
            channel_type: "terminal".into(),
            config: serde_json::json!({}),
        };
        let channel = registry.build_review_from_config(&route);
        assert!(channel.is_ok());
    }

    #[test]
    fn build_review_unknown_type_errors() {
        let registry = default_registry();
        let route = ChannelRouteConfig {
            channel_type: "slack".into(),
            config: serde_json::json!({}),
        };
        let result = registry.build_review_from_config(&route);
        assert!(result.is_err());
    }

    #[test]
    fn build_session_unknown_type_errors() {
        let registry = default_registry();
        let route = ChannelRouteConfig {
            channel_type: "slack".into(),
            config: serde_json::json!({}),
        };
        assert!(registry.build_session_from_config(&route).is_err());
    }

    #[test]
    fn webhook_factory_rejects_sessions() {
        let registry = default_registry();
        let route = ChannelRouteConfig {
            channel_type: "webhook".into(),
            config: serde_json::json!({ "endpoint": "/tmp/session" }),
        };
        assert!(registry.build_session_from_config(&route).is_err());
    }

    #[test]
    fn channel_routing_config_single_review() {
        let yaml = r#"
review:
  type: terminal
notify:
  - type: terminal
  - type: webhook
    endpoint: "/tmp/notify"
    level: warning
session:
  type: terminal
escalation:
  type: webhook
  endpoint: "/tmp/escalate"
default_agent: claude-code
"#;
        let config: ChannelRoutingConfig = serde_yaml::from_str(yaml).unwrap();
        let review_configs = config.review.configs();
        assert_eq!(review_configs.len(), 1);
        assert_eq!(review_configs[0].channel_type, "terminal");
        assert!(!config.review.is_multi());
        assert_eq!(config.notify.len(), 2);
        assert_eq!(config.notify[0].level, "info");
        assert_eq!(config.notify[1].channel_type, "webhook");
        assert_eq!(config.notify[1].level, "warning");
        // Extra keys are flattened into `config` for the factory to read.
        assert_eq!(
            config.notify[1].config.get("endpoint").and_then(|v| v.as_str()),
            Some("/tmp/notify")
        );
        assert!(config.escalation.is_some());
        assert_eq!(config.default_agent.as_deref(), Some("claude-code"));
    }

    #[test]
    fn channel_routing_config_multi_review() {
        let yaml = r#"
review:
  - type: terminal
  - type: webhook
    endpoint: "/tmp/review"
session:
  type: terminal
strategy: first_response
"#;
        let config: ChannelRoutingConfig = serde_yaml::from_str(yaml).unwrap();
        let review_configs = config.review.configs();
        assert_eq!(review_configs.len(), 2);
        assert_eq!(review_configs[0].channel_type, "terminal");
        assert_eq!(review_configs[1].channel_type, "webhook");
        assert!(config.review.is_multi());
        assert_eq!(config.strategy.as_deref(), Some("first_response"));
    }

    #[test]
    fn channel_routing_config_multi_escalation() {
        let yaml = r#"
review:
  type: terminal
session:
  type: terminal
escalation:
  - type: webhook
    endpoint: "/tmp/esc1"
  - type: webhook
    endpoint: "/tmp/esc2"
"#;
        let config: ChannelRoutingConfig = serde_yaml::from_str(yaml).unwrap();
        let esc = config.escalation.unwrap();
        assert_eq!(esc.configs().len(), 2);
    }

    #[test]
    fn notify_level_defaults_to_info() {
        let yaml = r#"
notify:
  - type: terminal
"#;
        let config: ChannelRoutingConfig = serde_yaml::from_str(yaml).unwrap();
        assert_eq!(config.notify.len(), 1);
        assert_eq!(config.notify[0].level, "info");
    }

    #[test]
    fn ta_config_deserialization() {
        let yaml = r#"
channels:
  review:
    type: terminal
  session:
    type: terminal
"#;
        let config: TaConfig = serde_yaml::from_str(yaml).unwrap();
        let review_configs = config.channels.review.configs();
        assert_eq!(review_configs[0].channel_type, "terminal");
    }

    #[test]
    fn default_ta_config() {
        let config = TaConfig::default();
        let review_configs = config.channels.review.configs();
        assert_eq!(review_configs[0].channel_type, "terminal");
        assert!(config.channels.notify.is_empty());
    }

    #[test]
    fn build_multi_review_from_route_single() {
        let registry = default_registry();
        let route = ReviewRouteConfig::Single(ChannelRouteConfig {
            channel_type: "terminal".into(),
            config: serde_json::json!({}),
        });
        let strategy = crate::multi_channel::MultiChannelStrategy::FirstResponse;
        let channel = registry.build_review_from_route(&route, &strategy);
        assert!(channel.is_ok());
    }

    #[test]
    fn build_multi_review_from_route_multiple() {
        let registry = default_registry();
        let route = ReviewRouteConfig::Multiple(vec![
            ChannelRouteConfig {
                channel_type: "auto-approve".into(),
                config: serde_json::json!({}),
            },
            ChannelRouteConfig {
                channel_type: "auto-approve".into(),
                config: serde_json::json!({}),
            },
        ]);
        let strategy = crate::multi_channel::MultiChannelStrategy::FirstResponse;
        let channel = registry.build_review_from_route(&route, &strategy);
        assert!(channel.is_ok());
    }

    #[test]
    fn channel_capability_set_defaults() {
        let caps = ChannelCapabilitySet::default();
        assert!(caps.supports_review);
        assert!(caps.supports_session);
        assert!(caps.supports_notify);
        assert!(!caps.supports_rich_media);
        assert!(!caps.supports_threads);
    }

    #[test]
    fn register_custom_factory() {
        struct MockFactory;
        impl ChannelFactory for MockFactory {
            fn channel_type(&self) -> &str {
                "mock"
            }
            fn build_review(
                &self,
                _config: &serde_json::Value,
            ) -> Result<Box<dyn ReviewChannel>, ReviewChannelError> {
                Ok(Box::new(crate::terminal_channel::AutoApproveChannel::new()))
            }
            fn build_session(
                &self,
                _config: &serde_json::Value,
            ) -> Result<Box<dyn SessionChannel>, SessionChannelError> {
                Err(SessionChannelError::Other("mock".into()))
            }
            fn capabilities(&self) -> ChannelCapabilitySet {
                ChannelCapabilitySet::default()
            }
        }
        let mut registry = default_registry();
        registry.register(Box::new(MockFactory));
        assert!(registry.has_channel("mock"));
        assert_eq!(registry.len(), 4);
    }

    #[test]
    fn load_config_missing_file() {
        let dir = tempfile::TempDir::new().unwrap();
        let config = load_config(dir.path());
        assert_eq!(config.channels.review.configs()[0].channel_type, "terminal");
    }

    #[test]
    fn load_config_reads_project_file() {
        let dir = tempfile::TempDir::new().unwrap();
        let ta_dir = dir.path().join(".ta");
        std::fs::create_dir_all(&ta_dir).unwrap();
        std::fs::write(
            ta_dir.join("config.yaml"),
            "channels:\n  review:\n    type: auto-approve\n",
        )
        .unwrap();
        let config = load_config(dir.path());
        assert_eq!(
            config.channels.review.configs()[0].channel_type,
            "auto-approve"
        );
    }

    #[test]
    fn load_config_invalid_yaml_falls_back_to_default() {
        let dir = tempfile::TempDir::new().unwrap();
        let ta_dir = dir.path().join(".ta");
        std::fs::create_dir_all(&ta_dir).unwrap();
        std::fs::write(ta_dir.join("config.yaml"), "channels: [unterminated\n").unwrap();
        let config = load_config(dir.path());
        assert_eq!(config.channels.review.configs()[0].channel_type, "terminal");
    }

    #[test]
    fn webhook_factory_requires_endpoint() {
        let registry = default_registry();
        let route = ChannelRouteConfig {
            channel_type: "webhook".into(),
            config: serde_json::json!({}),
        };
        let result = registry.build_review_from_config(&route);
        assert!(result.is_err());
    }
}