1use std::collections::HashMap;
2use std::sync::Arc;
3
4use axum::Json;
5use axum::extract::{Path, State};
6use serde::{Deserialize, Serialize};
7
8use lago_core::EventQuery;
9use lago_core::event::{EventEnvelope, EventPayload};
10use lago_core::id::{BranchId, EventId, SeqNo, SessionId};
11
12use crate::error::ApiError;
13use crate::state::AppState;
14
15#[derive(Deserialize)]
18pub struct CreateBranchRequest {
19 pub name: String,
20 #[serde(default)]
21 pub fork_point_seq: Option<SeqNo>,
22}
23
24#[derive(Deserialize)]
25pub struct MergeBranchRequest {
26 pub target: String,
27}
28
29#[derive(Serialize)]
30pub struct MergeBranchResponse {
31 pub merged: bool,
32 pub strategy: String,
33 pub events_merged: usize,
34}
35
36#[derive(Serialize)]
37pub struct BranchResponse {
38 pub branch_id: String,
39 pub name: String,
40 pub fork_point_seq: SeqNo,
41}
42
43pub async fn create_branch(
50 State(state): State<Arc<AppState>>,
51 Path(session_id): Path<String>,
52 Json(body): Json<CreateBranchRequest>,
53) -> Result<(axum::http::StatusCode, Json<BranchResponse>), ApiError> {
54 let session_id = SessionId::from_string(session_id.clone());
55 let main_branch = BranchId::from_string("main");
56
57 state
59 .journal
60 .get_session(&session_id)
61 .await?
62 .ok_or_else(|| ApiError::NotFound(format!("session not found: {session_id}")))?;
63
64 let fork_point_seq = match body.fork_point_seq {
66 Some(seq) => seq,
67 None => state.journal.head_seq(&session_id, &main_branch).await?,
68 };
69
70 let new_branch_id = BranchId::new();
71
72 let event = EventEnvelope {
74 event_id: EventId::new(),
75 session_id: session_id.clone(),
76 branch_id: main_branch.clone(),
77 run_id: None,
78 seq: 0, timestamp: EventEnvelope::now_micros(),
80 parent_id: None,
81 payload: EventPayload::BranchCreated {
82 new_branch_id: new_branch_id.clone().into(),
83 fork_point_seq,
84 name: body.name.clone(),
85 },
86 metadata: HashMap::new(),
87 schema_version: 1,
88 };
89
90 state.journal.append(event).await?;
91
92 Ok((
93 axum::http::StatusCode::CREATED,
94 Json(BranchResponse {
95 branch_id: new_branch_id.to_string(),
96 name: body.name,
97 fork_point_seq,
98 }),
99 ))
100}
101
102pub async fn list_branches(
107 State(state): State<Arc<AppState>>,
108 Path(session_id): Path<String>,
109) -> Result<Json<Vec<BranchResponse>>, ApiError> {
110 let session_id = SessionId::from_string(session_id.clone());
111
112 let _session = state
114 .journal
115 .get_session(&session_id)
116 .await?
117 .ok_or_else(|| ApiError::NotFound(format!("session not found: {session_id}")))?;
118
119 let query = lago_core::EventQuery::new().session(session_id.clone());
121 let events = state.journal.read(query).await?;
122
123 let mut branches: Vec<BranchResponse> = Vec::new();
124
125 branches.push(BranchResponse {
127 branch_id: "main".to_string(),
128 name: "main".to_string(),
129 fork_point_seq: 0,
130 });
131
132 for event in &events {
134 if let EventPayload::BranchCreated {
135 ref new_branch_id,
136 fork_point_seq,
137 ref name,
138 } = event.payload
139 {
140 branches.push(BranchResponse {
141 branch_id: new_branch_id.as_str().to_string(),
142 name: name.clone(),
143 fork_point_seq,
144 });
145 }
146 }
147
148 Ok(Json(branches))
149}
150
151pub async fn merge_branch(
158 State(state): State<Arc<AppState>>,
159 Path((session_id, source_branch_name)): Path<(String, String)>,
160 Json(body): Json<MergeBranchRequest>,
161) -> Result<Json<MergeBranchResponse>, ApiError> {
162 let session_id = SessionId::from_string(session_id.clone());
163
164 state
166 .journal
167 .get_session(&session_id)
168 .await?
169 .ok_or_else(|| ApiError::NotFound(format!("session not found: {session_id}")))?;
170
171 let all_events_query = EventQuery::new().session(session_id.clone());
174 let all_events = state.journal.read(all_events_query).await?;
175
176 let mut name_to_id: HashMap<String, BranchId> = HashMap::new();
178 name_to_id.insert("main".to_string(), BranchId::from_string("main"));
179
180 let mut branch_fork_points: HashMap<String, SeqNo> = HashMap::new();
182 branch_fork_points.insert("main".to_string(), 0);
183
184 for event in &all_events {
185 if let EventPayload::BranchCreated {
186 ref new_branch_id,
187 fork_point_seq,
188 ref name,
189 } = event.payload
190 {
191 let lago_branch_id = BranchId::from_string(new_branch_id.as_str());
192 name_to_id.insert(name.clone(), lago_branch_id);
193 branch_fork_points.insert(name.clone(), fork_point_seq);
194 }
195 }
196
197 let source_branch_id = name_to_id
199 .get(&source_branch_name)
200 .cloned()
201 .ok_or_else(|| {
202 ApiError::NotFound(format!("source branch not found: {source_branch_name}"))
203 })?;
204
205 let source_fork_point = branch_fork_points
206 .get(&source_branch_name)
207 .copied()
208 .unwrap_or(0);
209
210 let target_branch_id = name_to_id
212 .get(&body.target)
213 .cloned()
214 .ok_or_else(|| ApiError::NotFound(format!("target branch not found: {}", body.target)))?;
215
216 let target_events_query = EventQuery::new()
221 .session(session_id.clone())
222 .branch(target_branch_id.clone())
223 .after(source_fork_point);
224 let target_events_after_fork = state.journal.read(target_events_query).await?;
225
226 let has_content_divergence = target_events_after_fork
227 .iter()
228 .any(|e| !matches!(e.payload, EventPayload::BranchCreated { .. }));
229
230 if has_content_divergence {
231 let target_head_seq = state
232 .journal
233 .head_seq(&session_id, &target_branch_id)
234 .await?;
235 return Err(ApiError::Conflict(format!(
236 "fast-forward not possible: source branch '{}' forked at seq {} \
237 but target branch '{}' has diverged (head at seq {}). \
238 A three-way merge is required.",
239 source_branch_name, source_fork_point, body.target, target_head_seq
240 )));
241 }
242
243 let source_query = EventQuery::new()
245 .session(session_id.clone())
246 .branch(source_branch_id.clone());
247 let source_events = state.journal.read(source_query).await?;
248
249 let mut merged_events: Vec<EventEnvelope> = Vec::new();
251 for source_event in &source_events {
252 let mut merged = source_event.clone();
253 merged.event_id = EventId::new();
254 merged.branch_id = target_branch_id.clone();
255 merged.seq = 0; merged.timestamp = EventEnvelope::now_micros();
257 merged_events.push(merged);
258 }
259
260 let events_merged = merged_events.len();
261
262 if !merged_events.is_empty() {
264 state.journal.append_batch(merged_events).await?;
265 }
266
267 let merge_event = EventEnvelope {
269 event_id: EventId::new(),
270 session_id: session_id.clone(),
271 branch_id: target_branch_id.clone(),
272 run_id: None,
273 seq: 0, timestamp: EventEnvelope::now_micros(),
275 parent_id: None,
276 payload: EventPayload::Custom {
277 event_type: "BranchMerged".to_string(),
278 data: serde_json::json!({
279 "source_branch": source_branch_name,
280 "source_branch_id": source_branch_id.as_str(),
281 "target_branch": body.target,
282 "target_branch_id": target_branch_id.as_str(),
283 "strategy": "fast-forward",
284 "events_merged": events_merged,
285 }),
286 },
287 metadata: HashMap::new(),
288 schema_version: 1,
289 };
290
291 state.journal.append(merge_event).await?;
292
293 Ok(Json(MergeBranchResponse {
294 merged: true,
295 strategy: "fast-forward".to_string(),
296 events_merged,
297 }))
298}