Skip to main content

lago_api/routes/
branches.rs

1use std::collections::HashMap;
2use std::sync::Arc;
3
4use axum::Json;
5use axum::extract::{Path, State};
6use serde::{Deserialize, Serialize};
7
8use lago_core::EventQuery;
9use lago_core::event::{EventEnvelope, EventPayload};
10use lago_core::id::{BranchId, EventId, SeqNo, SessionId};
11
12use crate::error::ApiError;
13use crate::state::AppState;
14
// --- Request / Response types
16
17#[derive(Deserialize)]
18pub struct CreateBranchRequest {
19    pub name: String,
20    #[serde(default)]
21    pub fork_point_seq: Option<SeqNo>,
22}
23
24#[derive(Deserialize)]
25pub struct MergeBranchRequest {
26    pub target: String,
27}
28
29#[derive(Serialize)]
30pub struct MergeBranchResponse {
31    pub merged: bool,
32    pub strategy: String,
33    pub events_merged: usize,
34}
35
36#[derive(Serialize)]
37pub struct BranchResponse {
38    pub branch_id: String,
39    pub name: String,
40    pub fork_point_seq: SeqNo,
41}
42
// --- Handlers
44
45/// POST /v1/sessions/:id/branches
46///
47/// Creates a new branch forked from the session's "main" branch at the
48/// given sequence number (defaults to the current head).
49pub async fn create_branch(
50    State(state): State<Arc<AppState>>,
51    Path(session_id): Path<String>,
52    Json(body): Json<CreateBranchRequest>,
53) -> Result<(axum::http::StatusCode, Json<BranchResponse>), ApiError> {
54    let session_id = SessionId::from_string(session_id.clone());
55    let main_branch = BranchId::from_string("main");
56
57    // Verify the session exists
58    state
59        .journal
60        .get_session(&session_id)
61        .await?
62        .ok_or_else(|| ApiError::NotFound(format!("session not found: {session_id}")))?;
63
64    // Determine fork point: use the provided seq or the current head
65    let fork_point_seq = match body.fork_point_seq {
66        Some(seq) => seq,
67        None => state.journal.head_seq(&session_id, &main_branch).await?,
68    };
69
70    let new_branch_id = BranchId::new();
71
72    // Emit a BranchCreated event on the main branch
73    let event = EventEnvelope {
74        event_id: EventId::new(),
75        session_id: session_id.clone(),
76        branch_id: main_branch.clone(),
77        run_id: None,
78        seq: 0, // Will be assigned by the journal
79        timestamp: EventEnvelope::now_micros(),
80        parent_id: None,
81        payload: EventPayload::BranchCreated {
82            new_branch_id: new_branch_id.clone().into(),
83            fork_point_seq,
84            name: body.name.clone(),
85        },
86        metadata: HashMap::new(),
87        schema_version: 1,
88    };
89
90    state.journal.append(event).await?;
91
92    Ok((
93        axum::http::StatusCode::CREATED,
94        Json(BranchResponse {
95            branch_id: new_branch_id.to_string(),
96            name: body.name,
97            fork_point_seq,
98        }),
99    ))
100}
101
102/// GET /v1/sessions/:id/branches
103///
104/// Lists all branches for a session. Currently reads BranchCreated events
105/// from the journal to reconstruct the branch list.
106pub async fn list_branches(
107    State(state): State<Arc<AppState>>,
108    Path(session_id): Path<String>,
109) -> Result<Json<Vec<BranchResponse>>, ApiError> {
110    let session_id = SessionId::from_string(session_id.clone());
111
112    // Verify the session exists
113    let _session = state
114        .journal
115        .get_session(&session_id)
116        .await?
117        .ok_or_else(|| ApiError::NotFound(format!("session not found: {session_id}")))?;
118
119    // Read all events for this session to find BranchCreated events
120    let query = lago_core::EventQuery::new().session(session_id.clone());
121    let events = state.journal.read(query).await?;
122
123    let mut branches: Vec<BranchResponse> = Vec::new();
124
125    // The "main" branch always exists for a session
126    branches.push(BranchResponse {
127        branch_id: "main".to_string(),
128        name: "main".to_string(),
129        fork_point_seq: 0,
130    });
131
132    // Extract BranchCreated events
133    for event in &events {
134        if let EventPayload::BranchCreated {
135            ref new_branch_id,
136            fork_point_seq,
137            ref name,
138        } = event.payload
139        {
140            branches.push(BranchResponse {
141                branch_id: new_branch_id.as_str().to_string(),
142                name: name.clone(),
143                fork_point_seq,
144            });
145        }
146    }
147
148    Ok(Json(branches))
149}
150
151/// POST /v1/sessions/:id/branches/:branch/merge
152///
153/// Merges the named branch INTO the target branch specified in the request body.
154/// Phase 1: Fast-forward merge only — succeeds when the source branch's
155/// fork_point_seq >= the target branch's head_seq, meaning the target has not
156/// diverged. Returns 409 Conflict if a fast-forward is not possible.
157pub async fn merge_branch(
158    State(state): State<Arc<AppState>>,
159    Path((session_id, source_branch_name)): Path<(String, String)>,
160    Json(body): Json<MergeBranchRequest>,
161) -> Result<Json<MergeBranchResponse>, ApiError> {
162    let session_id = SessionId::from_string(session_id.clone());
163
164    // Verify the session exists
165    state
166        .journal
167        .get_session(&session_id)
168        .await?
169        .ok_or_else(|| ApiError::NotFound(format!("session not found: {session_id}")))?;
170
171    // Resolve source and target branch IDs.
172    // We need to scan BranchCreated events to map branch names to IDs.
173    let all_events_query = EventQuery::new().session(session_id.clone());
174    let all_events = state.journal.read(all_events_query).await?;
175
176    // Build a name-to-id map from BranchCreated events. "main" always exists.
177    let mut name_to_id: HashMap<String, BranchId> = HashMap::new();
178    name_to_id.insert("main".to_string(), BranchId::from_string("main"));
179
180    // Also track fork_point_seq per branch
181    let mut branch_fork_points: HashMap<String, SeqNo> = HashMap::new();
182    branch_fork_points.insert("main".to_string(), 0);
183
184    for event in &all_events {
185        if let EventPayload::BranchCreated {
186            ref new_branch_id,
187            fork_point_seq,
188            ref name,
189        } = event.payload
190        {
191            let lago_branch_id = BranchId::from_string(new_branch_id.as_str());
192            name_to_id.insert(name.clone(), lago_branch_id);
193            branch_fork_points.insert(name.clone(), fork_point_seq);
194        }
195    }
196
197    // Look up source branch
198    let source_branch_id = name_to_id
199        .get(&source_branch_name)
200        .cloned()
201        .ok_or_else(|| {
202            ApiError::NotFound(format!("source branch not found: {source_branch_name}"))
203        })?;
204
205    let source_fork_point = branch_fork_points
206        .get(&source_branch_name)
207        .copied()
208        .unwrap_or(0);
209
210    // Look up target branch
211    let target_branch_id = name_to_id
212        .get(&body.target)
213        .cloned()
214        .ok_or_else(|| ApiError::NotFound(format!("target branch not found: {}", body.target)))?;
215
216    // Fast-forward check: verify the target branch has not received any
217    // content events after the source's fork point. BranchCreated events
218    // are metadata (they record the creation of other branches) and do not
219    // constitute divergence, so we exclude them from the check.
220    let target_events_query = EventQuery::new()
221        .session(session_id.clone())
222        .branch(target_branch_id.clone())
223        .after(source_fork_point);
224    let target_events_after_fork = state.journal.read(target_events_query).await?;
225
226    let has_content_divergence = target_events_after_fork
227        .iter()
228        .any(|e| !matches!(e.payload, EventPayload::BranchCreated { .. }));
229
230    if has_content_divergence {
231        let target_head_seq = state
232            .journal
233            .head_seq(&session_id, &target_branch_id)
234            .await?;
235        return Err(ApiError::Conflict(format!(
236            "fast-forward not possible: source branch '{}' forked at seq {} \
237             but target branch '{}' has diverged (head at seq {}). \
238             A three-way merge is required.",
239            source_branch_name, source_fork_point, body.target, target_head_seq
240        )));
241    }
242
243    // Read all events from the source branch
244    let source_query = EventQuery::new()
245        .session(session_id.clone())
246        .branch(source_branch_id.clone());
247    let source_events = state.journal.read(source_query).await?;
248
249    // Copy each source event to the target branch with new event IDs
250    let mut merged_events: Vec<EventEnvelope> = Vec::new();
251    for source_event in &source_events {
252        let mut merged = source_event.clone();
253        merged.event_id = EventId::new();
254        merged.branch_id = target_branch_id.clone();
255        merged.seq = 0; // Will be assigned by the journal
256        merged.timestamp = EventEnvelope::now_micros();
257        merged_events.push(merged);
258    }
259
260    let events_merged = merged_events.len();
261
262    // Append all merged events to the target branch
263    if !merged_events.is_empty() {
264        state.journal.append_batch(merged_events).await?;
265    }
266
267    // Emit a BranchMerged custom event on the target branch
268    let merge_event = EventEnvelope {
269        event_id: EventId::new(),
270        session_id: session_id.clone(),
271        branch_id: target_branch_id.clone(),
272        run_id: None,
273        seq: 0, // Will be assigned by the journal
274        timestamp: EventEnvelope::now_micros(),
275        parent_id: None,
276        payload: EventPayload::Custom {
277            event_type: "BranchMerged".to_string(),
278            data: serde_json::json!({
279                "source_branch": source_branch_name,
280                "source_branch_id": source_branch_id.as_str(),
281                "target_branch": body.target,
282                "target_branch_id": target_branch_id.as_str(),
283                "strategy": "fast-forward",
284                "events_merged": events_merged,
285            }),
286        },
287        metadata: HashMap::new(),
288        schema_version: 1,
289    };
290
291    state.journal.append(merge_event).await?;
292
293    Ok(Json(MergeBranchResponse {
294        merged: true,
295        strategy: "fast-forward".to_string(),
296        events_merged,
297    }))
298}