use actix_web::{web, HttpResponse};
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use uuid::Uuid;
/// A named branch of a conversation session, forked at a specific message.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SessionBranch {
/// Unique identifier; `BranchingService::create_branch` generates a random UUID v4 string.
pub id: String,
/// Human-readable name, taken from `ForkRequest::branch_name` and changeable via `update_branch`.
pub name: String,
/// Session this branch belongs to; `list_branches` filters on this value.
pub session_id: String,
/// Branch this one was forked from, or `None` when not recorded.
pub parent_branch_id: Option<String>,
/// ID of the message the branch was forked from (`ForkRequest::fork_from_message_id`).
pub fork_point_message_id: String,
/// Creation time (UTC).
pub created_at: DateTime<Utc>,
/// Last modification time (UTC); refreshed by `update_branch` and `archive_branch`.
pub updated_at: DateTime<Utc>,
/// Optional free-text description.
pub description: Option<String>,
/// When `true`, `delete_branch` refuses to delete this branch.
pub is_default: bool,
/// Set by `archive_branch`; archiving only flips this flag and does not remove messages.
pub is_archived: bool,
/// Arbitrary JSON metadata; `create_branch` starts it empty.
pub metadata: HashMap<String, serde_json::Value>,
}
/// A single message stored on a branch.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BranchedMessage {
/// Unique identifier; `add_message` generates a random UUID v4 string.
pub id: String,
/// ID of the branch the message was originally added to.
pub branch_id: String,
/// Optional link to a preceding message; `add_message` leaves it `None`.
pub parent_message_id: Option<String>,
/// Speaker role exactly as supplied by the caller (not validated).
pub role: String,
/// Message text.
pub content: String,
/// Creation time (UTC); the `Interleave` merge strategy orders messages by it.
pub timestamp: DateTime<Utc>,
/// Arbitrary JSON metadata; `add_message` starts it empty.
pub metadata: HashMap<String, serde_json::Value>,
}
/// Request body for forking a new branch (`BranchingService::create_branch`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ForkRequest {
/// Session to fork within; becomes the new branch's `session_id`.
pub session_id: String,
/// Message at which the new branch diverges; stored as `fork_point_message_id`.
pub fork_from_message_id: String,
/// Name for the new branch.
pub branch_name: String,
/// Optional description for the new branch.
pub description: Option<String>,
}
/// Request body for `BranchingService::merge_branches`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MergeRequest {
/// Branch whose messages are merged in.
pub source_branch_id: String,
/// Branch that receives the result; its message list is replaced by the merged list.
pub target_branch_id: String,
/// How the two histories are combined.
pub strategy: MergeStrategy,
/// Optional conflict-handling preferences.
/// NOTE(review): `merge_branches` does not read this field — confirm whether that is intended.
pub conflict_resolution: Option<ConflictResolution>,
}
/// How `BranchingService::merge_branches` combines a source branch into a target branch.
///
/// Serialized in `snake_case` (e.g. `"keep_target"`); defaults to [`MergeStrategy::Append`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "snake_case")]
pub enum MergeStrategy {
    /// Keep the target history and append source messages that the target does
    /// not already contain (matched by role and content).
    #[default]
    Append,
    /// Combine both histories and sort them by timestamp; equal timestamps coming
    /// from different branches are reported as `TimestampCollision` conflicts.
    Interleave,
    /// Keep the target history unchanged.
    KeepTarget,
    /// Replace the target history with the source history.
    KeepSource,
    /// Keep the target history and report source messages as `DivergentPath`
    /// conflicts instead of merging them.
    Manual,
}
/// Caller preferences for resolving merge conflicts.
/// NOTE(review): nothing in this file reads these values yet; `merge_branches` ignores them.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConflictResolution {
/// Preference for messages whose timestamps collide.
pub timestamp_conflict: TimestampConflictStrategy,
/// Preference for messages with duplicate content.
pub duplicate_content: DuplicateStrategy,
}
/// Preference for resolving messages from two branches that share a timestamp.
///
/// Serialized in `snake_case`; defaults to [`TimestampConflictStrategy::KeepBoth`].
/// NOTE(review): no code in this file consumes this value yet, so the variants'
/// exact semantics are defined only by their names.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "snake_case")]
pub enum TimestampConflictStrategy {
    #[default]
    KeepBoth,
    PreferSource,
    PreferTarget,
}
/// Preference for handling messages whose content is duplicated across branches.
///
/// Serialized in `snake_case`; defaults to [`DuplicateStrategy::KeepFirst`].
/// NOTE(review): no code in this file consumes this value yet, so the variants'
/// exact semantics are defined only by their names.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "snake_case")]
pub enum DuplicateStrategy {
    #[default]
    KeepFirst,
    KeepLast,
    KeepBoth,
    Remove,
}
/// Outcome of `BranchingService::merge_branches`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MergeResult {
/// `true` when the merge produced no conflicts.
pub success: bool,
/// Branch that now holds the merged messages (the merge target).
pub result_branch_id: String,
/// Number of messages stored on the target after the merge (the whole list, not only additions).
pub messages_merged: usize,
/// Conflicts detected during the merge.
pub conflicts: Vec<MergeConflict>,
/// Human-readable summary naming the source branch and strategy.
pub merge_message: String,
}
/// A single conflict found while merging two branches.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MergeConflict {
/// Unique identifier (random UUID v4 string when produced by `merge_branches`).
pub id: String,
/// Kind of conflict.
pub conflict_type: ConflictType,
/// Conflicting message from the source branch, if any.
pub source_message: Option<BranchedMessage>,
/// Conflicting message from the target branch, if any.
pub target_message: Option<BranchedMessage>,
/// Whether the conflict has been resolved; `merge_branches` always reports `false`.
pub resolved: bool,
/// Free-text resolution note; `merge_branches` always reports `None`.
pub resolution: Option<String>,
}
/// Kind of conflict recorded in a [`MergeConflict`]. Serialized in `snake_case`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ConflictType {
    /// Two messages from different branches share a timestamp (`Interleave` merges).
    TimestampCollision,
    /// Duplicate message content. Not produced by `merge_branches` in this file.
    DuplicateContent,
    /// A source message left unmerged for manual review (`Manual` merges).
    DivergentPath,
    /// A message without a valid place in the history. Not produced by
    /// `merge_branches` in this file.
    OrphanedMessage,
}
/// Message-level diff between two branches, matched by message ID.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BranchComparison {
/// Branch used as the baseline.
pub base_branch_id: String,
/// Branch compared against the baseline.
pub compare_branch_id: String,
/// ID of a message present on both branches, or `None` when they share nothing
/// (see `BranchingService::compare_branches` for which shared message is chosen).
pub common_ancestor_id: Option<String>,
/// Messages only on the base branch.
pub base_only: Vec<BranchedMessage>,
/// Messages only on the compare branch.
pub compare_only: Vec<BranchedMessage>,
/// Messages on both branches, in base-branch order.
pub common: Vec<BranchedMessage>,
/// Number of messages only on the compare branch (`compare_only.len()`).
pub ahead: usize,
/// Number of messages only on the base branch (`base_only.len()`).
pub behind: usize,
}
/// In-memory store of branches and their messages, shared by the HTTP handlers below.
pub struct BranchingService {
/// All known branches, keyed by branch ID.
branches: std::sync::RwLock<HashMap<String, SessionBranch>>,
/// Message history per branch, keyed by branch ID.
messages: std::sync::RwLock<HashMap<String, Vec<BranchedMessage>>>,
}
impl BranchingService {
    /// Creates an empty service with no branches and no messages.
    pub fn new() -> Self {
        Self {
            branches: std::sync::RwLock::new(HashMap::new()),
            messages: std::sync::RwLock::new(HashMap::new()),
        }
    }

    /// Forks a new branch of `request.session_id` at `request.fork_from_message_id`.
    ///
    /// The new branch starts with a copy of the history, up to and including the
    /// fork message, of the oldest branch of the same session that contains that
    /// message. That branch is recorded as `parent_branch_id`. Copied messages keep
    /// their IDs so `compare_branches` and `merge_branches` recognise the shared
    /// history. If no branch of the session contains the fork message, the new
    /// branch starts empty and has no parent.
    ///
    /// # Errors
    /// Returns the lock error text if either internal lock is poisoned.
    pub async fn create_branch(&self, request: ForkRequest) -> Result<SessionBranch, String> {
        // Lock order (branches, then messages) matches `delete_branch`. Holding both
        // registers the branch and its history atomically.
        let mut branches = self.branches.write().map_err(|e| e.to_string())?;
        let mut messages = self.messages.write().map_err(|e| e.to_string())?;

        let (parent_branch_id, history) = match Self::history_up_to(
            &branches,
            &messages,
            &request.session_id,
            &request.fork_from_message_id,
        ) {
            Some((parent_id, history)) => (Some(parent_id), history),
            None => (None, Vec::new()),
        };

        // A single timestamp so a fresh branch has created_at == updated_at.
        let now = Utc::now();
        let branch = SessionBranch {
            id: Uuid::new_v4().to_string(),
            name: request.branch_name,
            session_id: request.session_id,
            parent_branch_id,
            fork_point_message_id: request.fork_from_message_id,
            created_at: now,
            updated_at: now,
            description: request.description,
            is_default: false,
            is_archived: false,
            metadata: HashMap::new(),
        };
        messages.insert(branch.id.clone(), history);
        branches.insert(branch.id.clone(), branch.clone());
        Ok(branch)
    }

    /// Lists the branches of `session_id`, oldest first (ties broken by ID).
    ///
    /// # Errors
    /// Returns the lock error text if the branches lock is poisoned.
    pub async fn list_branches(&self, session_id: &str) -> Result<Vec<SessionBranch>, String> {
        let branches = self.branches.read().map_err(|e| e.to_string())?;
        let mut result: Vec<SessionBranch> = branches
            .values()
            .filter(|b| b.session_id == session_id)
            .cloned()
            .collect();
        // HashMap iteration order is arbitrary; sort so clients get a stable listing.
        result.sort_by(|a, b| {
            a.created_at
                .cmp(&b.created_at)
                .then_with(|| a.id.cmp(&b.id))
        });
        Ok(result)
    }

    /// Returns a copy of the branch with ID `branch_id`.
    ///
    /// # Errors
    /// `"Branch not found: <id>"` if it does not exist, or the lock error text.
    pub async fn get_branch(&self, branch_id: &str) -> Result<SessionBranch, String> {
        let branches = self.branches.read().map_err(|e| e.to_string())?;
        branches
            .get(branch_id)
            .cloned()
            .ok_or_else(|| format!("Branch not found: {}", branch_id))
    }

    /// Renames and/or re-describes a branch and refreshes `updated_at`.
    ///
    /// `None` leaves the corresponding field unchanged, so a description cannot
    /// be cleared through this method.
    ///
    /// # Errors
    /// `"Branch not found: <id>"` if it does not exist, or the lock error text.
    pub async fn update_branch(
        &self,
        branch_id: &str,
        name: Option<String>,
        description: Option<String>,
    ) -> Result<SessionBranch, String> {
        let mut branches = self.branches.write().map_err(|e| e.to_string())?;
        let branch = branches
            .get_mut(branch_id)
            .ok_or_else(|| format!("Branch not found: {}", branch_id))?;
        if let Some(n) = name {
            branch.name = n;
        }
        if let Some(d) = description {
            branch.description = Some(d);
        }
        branch.updated_at = Utc::now();
        Ok(branch.clone())
    }

    /// Deletes a branch together with its message history.
    ///
    /// # Errors
    /// `"Branch not found: <id>"`, `"Cannot delete the default branch"` for a
    /// branch with `is_default == true`, or the lock error text.
    pub async fn delete_branch(&self, branch_id: &str) -> Result<(), String> {
        // Take both locks before mutating anything (same order as `create_branch`).
        // Otherwise a poisoned messages lock could leave the branch removed while
        // its messages stay behind.
        let mut branches = self.branches.write().map_err(|e| e.to_string())?;
        let mut messages = self.messages.write().map_err(|e| e.to_string())?;
        let branch = branches
            .get(branch_id)
            .ok_or_else(|| format!("Branch not found: {}", branch_id))?;
        if branch.is_default {
            return Err("Cannot delete the default branch".to_string());
        }
        branches.remove(branch_id);
        messages.remove(branch_id);
        Ok(())
    }

    /// Marks a branch as archived and refreshes `updated_at`. Its messages are kept.
    ///
    /// # Errors
    /// `"Branch not found: <id>"` if it does not exist, or the lock error text.
    pub async fn archive_branch(&self, branch_id: &str) -> Result<SessionBranch, String> {
        let mut branches = self.branches.write().map_err(|e| e.to_string())?;
        let branch = branches
            .get_mut(branch_id)
            .ok_or_else(|| format!("Branch not found: {}", branch_id))?;
        branch.is_archived = true;
        branch.updated_at = Utc::now();
        Ok(branch.clone())
    }

    /// Merges the source branch's messages into the target branch using
    /// `request.strategy`, replacing the target's message list with the result.
    ///
    /// Unknown branch IDs are treated as empty histories. Source messages whose ID
    /// already exists on the target (history shared since the fork, or merged
    /// earlier) are never re-added or reported as conflicts.
    /// NOTE(review): `request.conflict_resolution` is not consulted.
    ///
    /// # Errors
    /// Fails when source and target are the same branch, or with the lock error text.
    pub async fn merge_branches(&self, request: MergeRequest) -> Result<MergeResult, String> {
        if request.source_branch_id == request.target_branch_id {
            return Err("Cannot merge a branch into itself".to_string());
        }

        let (messages_merged, conflicts) = {
            // Hold the write lock for the whole read-merge-write cycle. Otherwise
            // messages added to the target concurrently would be silently overwritten.
            let mut messages = self.messages.write().map_err(|e| e.to_string())?;
            let source = messages
                .get(&request.source_branch_id)
                .cloned()
                .unwrap_or_default();
            let target = messages
                .get(&request.target_branch_id)
                .cloned()
                .unwrap_or_default();
            let (merged, conflicts) = Self::merge_messages(request.strategy, source, target);
            let count = merged.len();
            messages.insert(request.target_branch_id.clone(), merged);
            (count, conflicts)
        };

        Ok(MergeResult {
            success: conflicts.is_empty(),
            result_branch_id: request.target_branch_id,
            messages_merged,
            conflicts,
            merge_message: format!(
                "Merged branch {} using {:?} strategy",
                request.source_branch_id, request.strategy
            ),
        })
    }

    /// Diffs two branches by message ID.
    ///
    /// `common_ancestor_id` is the last message (in base order) present on both
    /// branches. That is the most recent shared message, i.e. the point where the
    /// histories diverge. Unknown branch IDs are treated as empty histories.
    ///
    /// # Errors
    /// Returns the lock error text if the messages lock is poisoned.
    pub async fn compare_branches(
        &self,
        base_branch_id: &str,
        compare_branch_id: &str,
    ) -> Result<BranchComparison, String> {
        let (base_messages, compare_messages) = {
            // One read lock gives a consistent snapshot of both branches.
            let messages = self.messages.read().map_err(|e| e.to_string())?;
            (
                messages.get(base_branch_id).cloned().unwrap_or_default(),
                messages.get(compare_branch_id).cloned().unwrap_or_default(),
            )
        };
        let base_ids: std::collections::HashSet<&str> =
            base_messages.iter().map(|m| m.id.as_str()).collect();
        let compare_ids: std::collections::HashSet<&str> =
            compare_messages.iter().map(|m| m.id.as_str()).collect();

        let (common, base_only): (Vec<BranchedMessage>, Vec<BranchedMessage>) = base_messages
            .iter()
            .cloned()
            .partition(|m| compare_ids.contains(m.id.as_str()));
        let compare_only: Vec<BranchedMessage> = compare_messages
            .iter()
            .filter(|m| !base_ids.contains(m.id.as_str()))
            .cloned()
            .collect();

        Ok(BranchComparison {
            base_branch_id: base_branch_id.to_string(),
            compare_branch_id: compare_branch_id.to_string(),
            // The most recent shared message is the fork point, not the first one.
            common_ancestor_id: common.last().map(|m| m.id.clone()),
            ahead: compare_only.len(),
            behind: base_only.len(),
            base_only,
            compare_only,
            common,
        })
    }

    /// Returns a copy of a branch's messages in stored order.
    ///
    /// An unknown branch ID yields an empty list rather than an error.
    ///
    /// # Errors
    /// Returns the lock error text if the messages lock is poisoned.
    pub async fn get_branch_messages(
        &self,
        branch_id: &str,
    ) -> Result<Vec<BranchedMessage>, String> {
        let messages = self.messages.read().map_err(|e| e.to_string())?;
        Ok(messages.get(branch_id).cloned().unwrap_or_default())
    }

    /// Appends a new message (fresh UUID, current UTC time, no parent, empty
    /// metadata) to `branch_id`.
    ///
    /// The branch is not required to exist; messages for an unknown ID are stored
    /// under that ID.
    ///
    /// # Errors
    /// Returns the lock error text if the messages lock is poisoned.
    pub async fn add_message(
        &self,
        branch_id: &str,
        role: &str,
        content: &str,
    ) -> Result<BranchedMessage, String> {
        let message = BranchedMessage {
            id: Uuid::new_v4().to_string(),
            branch_id: branch_id.to_string(),
            parent_message_id: None,
            role: role.to_string(),
            content: content.to_string(),
            timestamp: Utc::now(),
            metadata: HashMap::new(),
        };
        let mut messages = self.messages.write().map_err(|e| e.to_string())?;
        messages
            .entry(branch_id.to_string())
            .or_default()
            .push(message.clone());
        Ok(message)
    }

    /// Finds the oldest branch of `session_id` whose history contains
    /// `fork_message_id`. Returns that branch's ID and its messages up to and
    /// including the fork message, or `None` if no branch contains it.
    fn history_up_to(
        branches: &HashMap<String, SessionBranch>,
        messages: &HashMap<String, Vec<BranchedMessage>>,
        session_id: &str,
        fork_message_id: &str,
    ) -> Option<(String, Vec<BranchedMessage>)> {
        let mut candidates: Vec<&SessionBranch> = branches
            .values()
            .filter(|b| b.session_id == session_id)
            .collect();
        // Oldest first. Copies keep message IDs, so the earliest branch holding the
        // message is where it was written; this also makes the choice deterministic.
        candidates.sort_by(|a, b| {
            a.created_at
                .cmp(&b.created_at)
                .then_with(|| a.id.cmp(&b.id))
        });
        candidates.into_iter().find_map(|branch| {
            let history = messages.get(&branch.id)?;
            let end = history.iter().position(|m| m.id == fork_message_id)?;
            Some((branch.id.clone(), history[..=end].to_vec()))
        })
    }

    /// Combines `source` into `target` according to `strategy`. Returns the new
    /// target history and any conflicts found.
    fn merge_messages(
        strategy: MergeStrategy,
        source: Vec<BranchedMessage>,
        target: Vec<BranchedMessage>,
    ) -> (Vec<BranchedMessage>, Vec<MergeConflict>) {
        match strategy {
            MergeStrategy::Append => {
                // Skip source messages whose (role, content) already appears on the target.
                let additions: Vec<BranchedMessage> = {
                    let seen: std::collections::HashSet<(&str, &str)> = target
                        .iter()
                        .map(|t| (t.role.as_str(), t.content.as_str()))
                        .collect();
                    source
                        .into_iter()
                        .filter(|m| !seen.contains(&(m.role.as_str(), m.content.as_str())))
                        .collect()
                };
                let mut merged = target;
                merged.extend(additions);
                (merged, Vec::new())
            }
            MergeStrategy::Interleave => {
                let source_only = Self::messages_not_in(source, &target);
                // Tag each message with its origin (true = source) so collisions
                // can be detected and labelled by side.
                let mut tagged: Vec<(bool, BranchedMessage)> = target
                    .into_iter()
                    .map(|m| (false, m))
                    .chain(source_only.into_iter().map(|m| (true, m)))
                    .collect();
                // Stable sort: on equal timestamps, target messages stay ahead of source ones.
                tagged.sort_by_key(|(_, m)| m.timestamp);
                let conflicts: Vec<MergeConflict> = tagged
                    .windows(2)
                    .filter(|pair| {
                        pair[0].0 != pair[1].0 && pair[0].1.timestamp == pair[1].1.timestamp
                    })
                    .map(|pair| {
                        let (src, tgt) = if pair[1].0 {
                            (&pair[1].1, &pair[0].1)
                        } else {
                            (&pair[0].1, &pair[1].1)
                        };
                        Self::new_conflict(
                            ConflictType::TimestampCollision,
                            Some(src.clone()),
                            Some(tgt.clone()),
                        )
                    })
                    .collect();
                let merged: Vec<BranchedMessage> = tagged.into_iter().map(|(_, m)| m).collect();
                (merged, conflicts)
            }
            MergeStrategy::KeepTarget => (target, Vec::new()),
            MergeStrategy::KeepSource => (source, Vec::new()),
            MergeStrategy::Manual => {
                let conflicts: Vec<MergeConflict> = Self::messages_not_in(source, &target)
                    .into_iter()
                    .map(|m| Self::new_conflict(ConflictType::DivergentPath, Some(m), None))
                    .collect();
                (target, conflicts)
            }
        }
    }

    /// Returns the messages of `source` whose `id` does not appear in `target`.
    fn messages_not_in(
        source: Vec<BranchedMessage>,
        target: &[BranchedMessage],
    ) -> Vec<BranchedMessage> {
        let target_ids: std::collections::HashSet<&str> =
            target.iter().map(|m| m.id.as_str()).collect();
        source
            .into_iter()
            .filter(|m| !target_ids.contains(m.id.as_str()))
            .collect()
    }

    /// Builds an unresolved conflict record with a fresh UUID.
    fn new_conflict(
        conflict_type: ConflictType,
        source_message: Option<BranchedMessage>,
        target_message: Option<BranchedMessage>,
    ) -> MergeConflict {
        MergeConflict {
            id: Uuid::new_v4().to_string(),
            conflict_type,
            source_message,
            target_message,
            resolved: false,
            resolution: None,
        }
    }
}
impl Default for BranchingService {
fn default() -> Self {
Self::new()
}
}
pub async fn create_branch(
service: web::Data<BranchingService>,
request: web::Json<ForkRequest>,
) -> HttpResponse {
match service.create_branch(request.into_inner()).await {
Ok(branch) => HttpResponse::Created().json(branch),
Err(e) => HttpResponse::BadRequest().json(serde_json::json!({ "error": e })),
}
}
pub async fn list_branches(
service: web::Data<BranchingService>,
query: web::Query<HashMap<String, String>>,
) -> HttpResponse {
let session_id = match query.get("session_id") {
Some(id) => id,
None => {
return HttpResponse::BadRequest()
.json(serde_json::json!({ "error": "session_id required" }))
}
};
match service.list_branches(session_id).await {
Ok(branches) => HttpResponse::Ok().json(branches),
Err(e) => HttpResponse::InternalServerError().json(serde_json::json!({ "error": e })),
}
}
pub async fn get_branch(
service: web::Data<BranchingService>,
path: web::Path<String>,
) -> HttpResponse {
match service.get_branch(&path.into_inner()).await {
Ok(branch) => HttpResponse::Ok().json(branch),
Err(e) => HttpResponse::NotFound().json(serde_json::json!({ "error": e })),
}
}
/// JSON body for `PATCH /branches/{id}`; omitted fields are left unchanged.
#[derive(Debug, Deserialize)]
pub struct UpdateBranchRequest {
/// New branch name, if changing it.
pub name: Option<String>,
/// New description, if changing it.
pub description: Option<String>,
}
pub async fn update_branch(
service: web::Data<BranchingService>,
path: web::Path<String>,
body: web::Json<UpdateBranchRequest>,
) -> HttpResponse {
let request = body.into_inner();
match service
.update_branch(&path.into_inner(), request.name, request.description)
.await
{
Ok(branch) => HttpResponse::Ok().json(branch),
Err(e) => HttpResponse::BadRequest().json(serde_json::json!({ "error": e })),
}
}
pub async fn delete_branch(
service: web::Data<BranchingService>,
path: web::Path<String>,
) -> HttpResponse {
match service.delete_branch(&path.into_inner()).await {
Ok(()) => HttpResponse::NoContent().finish(),
Err(e) => HttpResponse::BadRequest().json(serde_json::json!({ "error": e })),
}
}
pub async fn archive_branch(
service: web::Data<BranchingService>,
path: web::Path<String>,
) -> HttpResponse {
match service.archive_branch(&path.into_inner()).await {
Ok(branch) => HttpResponse::Ok().json(branch),
Err(e) => HttpResponse::BadRequest().json(serde_json::json!({ "error": e })),
}
}
/// HTTP handler (`POST /branches/merge`): merges one branch into another.
///
/// Responds `200` with the merge result (even when it contains conflicts), or
/// `400 Bad Request` with `{"error": ...}` on a service error.
pub async fn merge_branches(
    service: web::Data<BranchingService>,
    request: web::Json<MergeRequest>,
) -> HttpResponse {
    let merge = request.into_inner();
    let outcome = service.merge_branches(merge).await;
    match outcome {
        Err(e) => HttpResponse::BadRequest().json(serde_json::json!({ "error": e })),
        Ok(summary) => HttpResponse::Ok().json(summary),
    }
}
pub async fn compare_branches(
service: web::Data<BranchingService>,
path: web::Path<(String, String)>,
) -> HttpResponse {
let (base_id, compare_id) = path.into_inner();
match service.compare_branches(&base_id, &compare_id).await {
Ok(comparison) => HttpResponse::Ok().json(comparison),
Err(e) => HttpResponse::BadRequest().json(serde_json::json!({ "error": e })),
}
}
/// HTTP handler (`GET /branches/{id}/messages`): returns a branch's messages.
///
/// Responds `200` with the message list (empty for an unknown branch), or
/// `500` with `{"error": ...}` on a service error.
pub async fn get_branch_messages(
    service: web::Data<BranchingService>,
    path: web::Path<String>,
) -> HttpResponse {
    let branch_id = path.into_inner();
    match service.get_branch_messages(&branch_id).await {
        Err(e) => HttpResponse::InternalServerError().json(serde_json::json!({ "error": e })),
        Ok(history) => HttpResponse::Ok().json(history),
    }
}
/// JSON body for `POST /branches/{id}/messages`.
#[derive(Debug, Deserialize)]
pub struct AddMessageRequest {
/// Speaker role, stored as given.
pub role: String,
/// Message text.
pub content: String,
}
pub async fn add_message(
service: web::Data<BranchingService>,
path: web::Path<String>,
body: web::Json<AddMessageRequest>,
) -> HttpResponse {
let request = body.into_inner();
match service
.add_message(&path.into_inner(), &request.role, &request.content)
.await
{
Ok(message) => HttpResponse::Created().json(message),
Err(e) => HttpResponse::BadRequest().json(serde_json::json!({ "error": e })),
}
}
/// Registers every branching endpoint under a `/branches` scope on `cfg`.
///
/// Registration order is preserved from the original wiring; `/merge` is a
/// `POST`-only route, while `/{id}` only accepts `GET`, `PATCH` and `DELETE`.
pub fn configure_branching_routes(cfg: &mut web::ServiceConfig) {
    let branch_scope = web::scope("/branches")
        // Collection-level routes.
        .route("", web::post().to(create_branch))
        .route("", web::get().to(list_branches))
        // Single-branch routes.
        .route("/{id}", web::get().to(get_branch))
        .route("/{id}", web::patch().to(update_branch))
        .route("/{id}", web::delete().to(delete_branch))
        .route("/{id}/archive", web::post().to(archive_branch))
        // Cross-branch operations.
        .route("/merge", web::post().to(merge_branches))
        .route("/{id}/compare/{other_id}", web::get().to(compare_branches))
        // Message history.
        .route("/{id}/messages", web::get().to(get_branch_messages))
        .route("/{id}/messages", web::post().to(add_message));
    cfg.service(branch_scope);
}