Skip to main content

byokey_proxy/handler/management/
mod.rs

1//! BYOKEY management `ConnectRPC` services.
2//!
3//! Three services split by domain:
4//! - [`StatusServiceImpl`] — server health, usage, rate limits
5//! - [`AccountsServiceImpl`] — provider account CRUD
6//! - [`AmpServiceImpl`] — local Amp CLI thread browsing
7
8use std::sync::Arc;
9
10use buffa::MessageField;
11use buffa::view::OwnedView;
12use buffa_types::google::protobuf::value::Kind;
13use buffa_types::google::protobuf::{ListValue, NullValue, Struct, Value};
14use connectrpc::{ConnectError, Context, Router as ConnectRouter};
15use serde_json::Value as JsonValue;
16
17use byokey_proto::byokey::accounts as acct;
18use byokey_proto::byokey::amp as amp_pb;
19use byokey_proto::byokey::status as stat;
20
21use crate::AppState;
22use crate::handler::amp::threads as internal_threads;
23
24// ───────────────────────── public entry point ─────────────────────
25
26/// Build a [`ConnectRouter`] with all three management services registered.
27#[must_use]
28pub fn build_router(state: Arc<AppState>) -> ConnectRouter {
29    use acct::AccountsServiceExt as _;
30    use amp_pb::AmpServiceExt as _;
31    use stat::StatusServiceExt as _;
32
33    let router = ConnectRouter::new();
34    let router = Arc::new(StatusServiceImpl(state.clone())).register(router);
35    let router = Arc::new(AccountsServiceImpl(state.clone())).register(router);
36    Arc::new(AmpServiceImpl(state)).register(router)
37}
38
39// ───────────────────────── helpers ────────────────────────────────
40
41fn byok_to_connect_error(e: &byokey_types::ByokError) -> ConnectError {
42    use byokey_types::ByokError;
43    let msg = e.to_string();
44    match e {
45        ByokError::Auth(_) | ByokError::TokenNotFound(_) | ByokError::TokenExpired(_) => {
46            ConnectError::unauthenticated(msg)
47        }
48        ByokError::UnsupportedModel(_) => ConnectError::not_found(msg),
49        ByokError::UnsupportedProvider(_) | ByokError::Translation(_) => {
50            ConnectError::invalid_argument(msg)
51        }
52        ByokError::ProviderUnavailable(_) => ConnectError::unavailable(msg),
53        _ => ConnectError::internal(msg),
54    }
55}
56
57fn json_to_pb_value(v: JsonValue) -> Value {
58    let kind = match v {
59        JsonValue::Null => Kind::NullValue(NullValue::NULL_VALUE.into()),
60        JsonValue::Bool(b) => Kind::BoolValue(b),
61        JsonValue::Number(n) => Kind::NumberValue(n.as_f64().unwrap_or(0.0)),
62        JsonValue::String(s) => Kind::StringValue(s),
63        JsonValue::Array(arr) => {
64            let values = arr.into_iter().map(json_to_pb_value).collect();
65            Kind::ListValue(Box::new(ListValue {
66                values,
67                ..Default::default()
68            }))
69        }
70        JsonValue::Object(map) => {
71            let fields = map
72                .into_iter()
73                .map(|(k, v)| (k, json_to_pb_value(v)))
74                .collect();
75            Kind::StructValue(Box::new(Struct {
76                fields,
77                ..Default::default()
78            }))
79        }
80    };
81    Value {
82        kind: Some(kind),
83        ..Default::default()
84    }
85}
86
87fn json_to_pb_struct(v: JsonValue) -> Struct {
88    if let JsonValue::Object(map) = v {
89        let fields = map
90            .into_iter()
91            .map(|(k, v)| (k, json_to_pb_value(v)))
92            .collect();
93        Struct {
94            fields,
95            ..Default::default()
96        }
97    } else {
98        Struct::default()
99    }
100}
101
/// Narrow a `usize` count to `u32`, saturating at `u32::MAX` instead of
/// wrapping when the value does not fit.
fn clamp_to_u32(n: usize) -> u32 {
    match u32::try_from(n) {
        Ok(fits) => fits,
        Err(_) => u32::MAX,
    }
}
105
/// Thirty days in seconds; `GetUsageByAccount` uses it as the default
/// look-back when the request omits `from`.
const THIRTY_DAYS_SECS: i64 = 30 * 86_400;
/// Largest time range a single `GetUsageByAccount` request can ask for.
/// Wider requests are rejected so adversarial clients cannot force a scan of
/// the entire usage table.
const MAX_USAGE_RANGE_SECS: i64 = 365 * 86_400;
110
111fn policy_strategy_to_proto(kind: byokey_config::PolicyStrategyKind) -> stat::RoutingStrategy {
112    use byokey_config::PolicyStrategyKind as K;
113    match kind {
114        K::RoundRobin => stat::RoutingStrategy::ROUTING_STRATEGY_ROUND_ROBIN,
115        K::WeightedRoundRobin => stat::RoutingStrategy::ROUTING_STRATEGY_WEIGHTED_ROUND_ROBIN,
116        K::Random => stat::RoutingStrategy::ROUTING_STRATEGY_RANDOM,
117        K::WeightedRandom => stat::RoutingStrategy::ROUTING_STRATEGY_WEIGHTED_RANDOM,
118        K::Priority => stat::RoutingStrategy::ROUTING_STRATEGY_PRIORITY,
119    }
120}
121
/// Current wall-clock time as whole seconds since the Unix epoch.
///
/// A clock set before the epoch yields `0`. A value too large for `i64`
/// saturates at `i64::MAX` instead of silently wrapping negative.
fn now_seconds() -> i64 {
    let secs = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .unwrap_or_default()
        .as_secs();
    i64::try_from(secs).unwrap_or(i64::MAX)
}
129
130// ═══════════════════════ StatusService ════════════════════════════
131
132struct StatusServiceImpl(Arc<AppState>);
133
134impl stat::StatusService for StatusServiceImpl {
135    async fn get_status(
136        &self,
137        ctx: Context,
138        _: OwnedView<stat::GetStatusRequestView<'static>>,
139    ) -> Result<(stat::GetStatusResponse, Context), ConnectError> {
140        let snapshot = self.0.config.load();
141        let server = stat::ServerInfo {
142            host: snapshot.host.clone(),
143            port: u32::from(snapshot.port),
144            ..Default::default()
145        };
146
147        let mut providers = Vec::new();
148        for pid in byokey_types::ProviderId::all() {
149            let cfg = snapshot.providers.get(pid);
150            let has_key = cfg.is_some_and(|c| c.api_key.is_some() || !c.api_keys.is_empty());
151            let auth = if has_key || self.0.auth.is_authenticated(pid).await {
152                stat::AuthStatus::AUTH_STATUS_VALID
153            } else {
154                let accts = self.0.auth.list_accounts(pid).await.unwrap_or_default();
155                if accts.is_empty() {
156                    stat::AuthStatus::AUTH_STATUS_NOT_CONFIGURED
157                } else {
158                    stat::AuthStatus::AUTH_STATUS_EXPIRED
159                }
160            };
161            providers.push(stat::ProviderStatus {
162                id: pid.to_string(),
163                display_name: pid.display_name().to_string(),
164                enabled: cfg.is_none_or(|c| c.enabled),
165                auth_status: auth.into(),
166                models_count: clamp_to_u32(byokey_provider::models_for_provider(pid).len()),
167                ..Default::default()
168            });
169        }
170        Ok((
171            stat::GetStatusResponse {
172                server: server.into(),
173                providers,
174                ..Default::default()
175            },
176            ctx,
177        ))
178    }
179
180    async fn get_usage(
181        &self,
182        ctx: Context,
183        _: OwnedView<stat::GetUsageRequestView<'static>>,
184    ) -> Result<(stat::GetUsageResponse, Context), ConnectError> {
185        let s = self.0.usage.snapshot();
186        let models = s
187            .models
188            .into_iter()
189            .map(|(k, m)| {
190                (
191                    k,
192                    stat::ModelStats {
193                        requests: m.requests,
194                        success: m.success,
195                        failure: m.failure,
196                        input_tokens: m.input_tokens,
197                        output_tokens: m.output_tokens,
198                        ..Default::default()
199                    },
200                )
201            })
202            .collect();
203        Ok((
204            stat::GetUsageResponse {
205                total_requests: s.total_requests,
206                success_requests: s.success_requests,
207                failure_requests: s.failure_requests,
208                input_tokens: s.input_tokens,
209                output_tokens: s.output_tokens,
210                models,
211                ..Default::default()
212            },
213            ctx,
214        ))
215    }
216
217    async fn get_usage_history(
218        &self,
219        ctx: Context,
220        request: OwnedView<stat::GetUsageHistoryRequestView<'static>>,
221    ) -> Result<(stat::GetUsageHistoryResponse, Context), ConnectError> {
222        let req = request.to_owned_message();
223        let Some(store) = self.0.usage.store() else {
224            let to = now_seconds();
225            return Ok((
226                stat::GetUsageHistoryResponse {
227                    from: to - 86400,
228                    to,
229                    bucket_seconds: 3600,
230                    error: Some("no persistent usage store configured".into()),
231                    ..Default::default()
232                },
233                ctx,
234            ));
235        };
236        let to = req.to.unwrap_or_else(now_seconds);
237        let from = req.from.unwrap_or(to - 86400);
238        let range = to - from;
239        let bs = if range <= 86400 {
240            3600
241        } else if range <= 86400 * 7 {
242            21600
243        } else {
244            86400
245        };
246        let buckets = store
247            .query(from, to, req.model.as_deref(), bs)
248            .await
249            .map_err(|e| ConnectError::internal(e.to_string()))?
250            .into_iter()
251            .map(|b| stat::UsageBucket {
252                period_start: b.period_start,
253                model: b.model,
254                request_count: b.request_count,
255                input_tokens: b.input_tokens,
256                output_tokens: b.output_tokens,
257                ..Default::default()
258            })
259            .collect();
260        Ok((
261            stat::GetUsageHistoryResponse {
262                from,
263                to,
264                bucket_seconds: bs,
265                buckets,
266                error: None,
267                ..Default::default()
268            },
269            ctx,
270        ))
271    }
272
273    async fn get_usage_by_account(
274        &self,
275        ctx: Context,
276        request: OwnedView<stat::GetUsageByAccountRequestView<'static>>,
277    ) -> Result<(stat::GetUsageByAccountResponse, Context), ConnectError> {
278        let req = request.to_owned_message();
279        let Some(store) = self.0.usage.store() else {
280            return Ok((
281                stat::GetUsageByAccountResponse {
282                    rows: Vec::new(),
283                    error: Some("no persistent usage store configured".into()),
284                    ..Default::default()
285                },
286                ctx,
287            ));
288        };
289        let to = req.to.unwrap_or_else(now_seconds);
290        let from = req
291            .from
292            .unwrap_or_else(|| to.saturating_sub(THIRTY_DAYS_SECS));
293        if from > to {
294            return Err(ConnectError::invalid_argument(
295                "from must be less than or equal to to",
296            ));
297        }
298        // Reject unbounded ranges so adversarial clients can't force a
299        // full-table scan by requesting, e.g., epoch-0-to-now.
300        if to.saturating_sub(from) > MAX_USAGE_RANGE_SECS {
301            return Err(ConnectError::invalid_argument(format!(
302                "requested range exceeds maximum of {} days",
303                MAX_USAGE_RANGE_SECS / 86_400
304            )));
305        }
306        let totals = store
307            .totals_by_account(Some(from), Some(to))
308            .await
309            .map_err(|e| ConnectError::internal(e.to_string()))?;
310        let rows = totals
311            .into_iter()
312            .map(|t| stat::AccountUsageRow {
313                provider: t.provider,
314                account_id: t.account_id,
315                model: t.model,
316                request_count: t.request_count,
317                success_count: t.success_count,
318                input_tokens: t.input_tokens,
319                output_tokens: t.output_tokens,
320                ..Default::default()
321            })
322            .collect();
323        Ok((
324            stat::GetUsageByAccountResponse {
325                rows,
326                error: None,
327                ..Default::default()
328            },
329            ctx,
330        ))
331    }
332
333    async fn list_routing_policies(
334        &self,
335        ctx: Context,
336        _: OwnedView<stat::ListRoutingPoliciesRequestView<'static>>,
337    ) -> Result<(stat::ListRoutingPoliciesResponse, Context), ConnectError> {
338        let snapshot = self.0.config.load();
339        let policies = snapshot
340            .routing_policies
341            .iter()
342            .map(|entry| stat::RoutingPolicy {
343                provider: entry.provider.to_string(),
344                family: entry.family.clone().unwrap_or_default(),
345                strategy: policy_strategy_to_proto(entry.strategy).into(),
346                accounts: entry.accounts.clone(),
347                weights: entry.weights.clone(),
348                ..Default::default()
349            })
350            .collect();
351        Ok((
352            stat::ListRoutingPoliciesResponse {
353                policies,
354                ..Default::default()
355            },
356            ctx,
357        ))
358    }
359
360    async fn set_routing_policy(
361        &self,
362        _ctx: Context,
363        _: OwnedView<stat::SetRoutingPolicyRequestView<'static>>,
364    ) -> Result<(stat::SetRoutingPolicyResponse, Context), ConnectError> {
365        // TODO(slice-7): wire to ConfigWatcher hot-reload
366        // Server-side mutation of settings.json is not yet implemented.
367        // Edit settings.json directly until hot-reload is wired up.
368        Err(ConnectError::unimplemented(
369            "SetRoutingPolicy is not yet implemented; edit settings.json directly",
370        ))
371    }
372
373    async fn get_rate_limits(
374        &self,
375        ctx: Context,
376        _: OwnedView<stat::GetRateLimitsRequestView<'static>>,
377    ) -> Result<(stat::GetRateLimitsResponse, Context), ConnectError> {
378        let all = self.0.ratelimits.all();
379        let mut by_prov: std::collections::HashMap<
380            byokey_types::ProviderId,
381            Vec<stat::AccountRateLimit>,
382        > = std::collections::HashMap::new();
383        for ((prov, aid), snap) in all {
384            by_prov
385                .entry(prov)
386                .or_default()
387                .push(stat::AccountRateLimit {
388                    account_id: aid,
389                    snapshot: MessageField::some(stat::RateLimitSnapshot {
390                        headers: snap.headers,
391                        captured_at: snap.captured_at,
392                        ..Default::default()
393                    }),
394                    ..Default::default()
395                });
396        }
397        let providers = byokey_types::ProviderId::all()
398            .iter()
399            .filter_map(|pid| {
400                let accts = by_prov.remove(pid)?;
401                Some(stat::ProviderRateLimits {
402                    id: pid.to_string(),
403                    display_name: pid.display_name().to_string(),
404                    accounts: accts,
405                    ..Default::default()
406                })
407            })
408            .collect();
409        Ok((
410            stat::GetRateLimitsResponse {
411                providers,
412                ..Default::default()
413            },
414            ctx,
415        ))
416    }
417}
418
419// ═══════════════════════ AccountsService ══════════════════════════
420
421struct AccountsServiceImpl(Arc<AppState>);
422
423impl acct::AccountsService for AccountsServiceImpl {
424    async fn list_accounts(
425        &self,
426        ctx: Context,
427        _: OwnedView<acct::ListAccountsRequestView<'static>>,
428    ) -> Result<(acct::ListAccountsResponse, Context), ConnectError> {
429        let mut providers = Vec::new();
430        for pid in byokey_types::ProviderId::all() {
431            let infos = self.0.auth.list_accounts(pid).await.unwrap_or_default();
432            let tokens = self.0.auth.get_all_tokens(pid).await.unwrap_or_default();
433            let accounts = infos
434                .iter()
435                .map(|info| {
436                    let (ts, exp) = match tokens.iter().find(|(id, _)| id == &info.account_id) {
437                        Some((_, tok)) => {
438                            let s = match tok.state() {
439                                byokey_types::TokenState::Valid => {
440                                    acct::TokenState::TOKEN_STATE_VALID
441                                }
442                                byokey_types::TokenState::Expired => {
443                                    acct::TokenState::TOKEN_STATE_EXPIRED
444                                }
445                                byokey_types::TokenState::Invalid => {
446                                    acct::TokenState::TOKEN_STATE_INVALID
447                                }
448                            };
449                            (s, tok.expires_at)
450                        }
451                        None => (acct::TokenState::TOKEN_STATE_INVALID, None),
452                    };
453                    acct::AccountDetail {
454                        account_id: info.account_id.clone(),
455                        label: info.label.clone(),
456                        is_active: info.is_active,
457                        token_state: ts.into(),
458                        expires_at: exp,
459                        ..Default::default()
460                    }
461                })
462                .collect();
463            providers.push(acct::ProviderAccounts {
464                id: pid.to_string(),
465                display_name: pid.display_name().to_string(),
466                accounts,
467                ..Default::default()
468            });
469        }
470        Ok((
471            acct::ListAccountsResponse {
472                providers,
473                ..Default::default()
474            },
475            ctx,
476        ))
477    }
478
479    async fn remove_account(
480        &self,
481        ctx: Context,
482        request: OwnedView<acct::RemoveAccountRequestView<'static>>,
483    ) -> Result<(acct::RemoveAccountResponse, Context), ConnectError> {
484        let req = request.to_owned_message();
485        let pid: byokey_types::ProviderId = req
486            .provider
487            .parse()
488            .map_err(|e: byokey_types::ByokError| byok_to_connect_error(&e))?;
489        self.0
490            .auth
491            .remove_token_for(&pid, &req.account_id)
492            .await
493            .map_err(|e| byok_to_connect_error(&e))?;
494        Ok((acct::RemoveAccountResponse::default(), ctx))
495    }
496
497    async fn activate_account(
498        &self,
499        ctx: Context,
500        request: OwnedView<acct::ActivateAccountRequestView<'static>>,
501    ) -> Result<(acct::ActivateAccountResponse, Context), ConnectError> {
502        let req = request.to_owned_message();
503        let pid: byokey_types::ProviderId = req
504            .provider
505            .parse()
506            .map_err(|e: byokey_types::ByokError| byok_to_connect_error(&e))?;
507        self.0
508            .auth
509            .set_active_account(&pid, &req.account_id)
510            .await
511            .map_err(|e| byok_to_connect_error(&e))?;
512        Ok((acct::ActivateAccountResponse::default(), ctx))
513    }
514
515    async fn add_api_key(
516        &self,
517        ctx: Context,
518        request: OwnedView<acct::AddApiKeyRequestView<'static>>,
519    ) -> Result<(acct::AddApiKeyResponse, Context), ConnectError> {
520        let req = request.to_owned_message();
521        let pid: byokey_types::ProviderId = req
522            .provider
523            .parse()
524            .map_err(|e: byokey_types::ByokError| byok_to_connect_error(&e))?;
525        if req.api_key.trim().is_empty() {
526            return Err(ConnectError::invalid_argument("api_key cannot be empty"));
527        }
528        if req.api_key.len() > byokey_types::MAX_API_KEY_BYTES {
529            return Err(ConnectError::invalid_argument(format!(
530                "api_key exceeds maximum length of {} bytes",
531                byokey_types::MAX_API_KEY_BYTES
532            )));
533        }
534        let account_id = req
535            .account_id
536            .unwrap_or_else(|| byokey_types::DEFAULT_ACCOUNT.to_string());
537        let token = byokey_types::OAuthToken {
538            access_token: req.api_key.trim().to_string(),
539            refresh_token: None,
540            expires_at: None,
541            token_type: Some("api-key".to_string()),
542        };
543        self.0
544            .auth
545            .save_token_for(&pid, &account_id, req.label.as_deref(), token)
546            .await
547            .map_err(|e| byok_to_connect_error(&e))?;
548        Ok((
549            acct::AddApiKeyResponse {
550                account_id,
551                ..Default::default()
552            },
553            ctx,
554        ))
555    }
556
557    async fn import_claude_code(
558        &self,
559        ctx: Context,
560        request: OwnedView<acct::ImportClaudeCodeRequestView<'static>>,
561    ) -> Result<(acct::ImportClaudeCodeResponse, Context), ConnectError> {
562        let req = request.to_owned_message();
563        let token = byokey_auth::provider::claude_code::load_token()
564            .await
565            .map_err(|e| byok_to_connect_error(&e))?
566            .ok_or_else(|| {
567                ConnectError::failed_precondition(
568                    "no Claude Code credentials found — is Claude Code logged in on this machine?",
569                )
570            })?;
571        let pid = byokey_types::ProviderId::Claude;
572        let account_id = req
573            .account_id
574            .unwrap_or_else(|| byokey_types::CLAUDE_CODE_ACCOUNT.to_string());
575        let label = req.label.unwrap_or_else(|| "Claude Code".to_string());
576        self.0
577            .auth
578            .save_token_for(&pid, &account_id, Some(label.as_str()), token)
579            .await
580            .map_err(|e| byok_to_connect_error(&e))?;
581        Ok((
582            acct::ImportClaudeCodeResponse {
583                account_id,
584                ..Default::default()
585            },
586            ctx,
587        ))
588    }
589
590    async fn login(
591        &self,
592        ctx: Context,
593        request: OwnedView<acct::LoginRequestView<'static>>,
594    ) -> Result<
595        (
596            std::pin::Pin<
597                Box<dyn futures_util::Stream<Item = Result<acct::LoginEvent, ConnectError>> + Send>,
598            >,
599            Context,
600        ),
601        ConnectError,
602    > {
603        use futures_util::StreamExt as _;
604        use tokio_stream::wrappers::ReceiverStream;
605
606        let req = request.to_owned_message();
607        let pid: byokey_types::ProviderId = req
608            .provider
609            .parse()
610            .map_err(|e: byokey_types::ByokError| byok_to_connect_error(&e))?;
611        let account = req.account_id;
612
613        let (progress_tx, progress_rx) =
614            tokio::sync::mpsc::channel::<byokey_auth::flow::LoginProgress>(8);
615        let (event_tx, event_rx) =
616            tokio::sync::mpsc::channel::<Result<acct::LoginEvent, ConnectError>>(16);
617
618        let auth = self.0.auth.clone();
619        let event_tx_drive = event_tx.clone();
620        tokio::spawn(async move {
621            let mut progress_rx = progress_rx;
622            let account_ref = account.as_deref();
623            let login_fut =
624                byokey_auth::flow::login_with_events(&pid, &auth, account_ref, Some(progress_tx));
625            tokio::pin!(login_fut);
626
627            loop {
628                tokio::select! {
629                    biased;
630                    Some(p) = progress_rx.recv() => {
631                        let ev = progress_to_pb(&p);
632                        if event_tx_drive.send(Ok(ev)).await.is_err() { return; }
633                    }
634                    res = &mut login_fut => {
635                        // Drain any remaining progress events before emitting terminal.
636                        while let Ok(p) = progress_rx.try_recv() {
637                            let ev = progress_to_pb(&p);
638                            let _ = event_tx_drive.send(Ok(ev)).await;
639                        }
640                        let terminal = match res {
641                            Ok(()) => acct::LoginEvent {
642                                stage: acct::LoginStage::LOGIN_STAGE_DONE.into(),
643                                ..Default::default()
644                            },
645                            Err(e) => acct::LoginEvent {
646                                stage: acct::LoginStage::LOGIN_STAGE_FAILED.into(),
647                                error: Some(e.to_string()),
648                                ..Default::default()
649                            },
650                        };
651                        let _ = event_tx_drive.send(Ok(terminal)).await;
652                        return;
653                    }
654                }
655            }
656        });
657
658        let stream = ReceiverStream::new(event_rx).boxed();
659        Ok((stream, ctx))
660    }
661}
662
663fn progress_to_pb(p: &byokey_auth::flow::LoginProgress) -> acct::LoginEvent {
664    use byokey_auth::flow::LoginProgress as P;
665    let (stage, message, user_code) = match p {
666        P::Started => (acct::LoginStage::LOGIN_STAGE_STARTED, None, None),
667        P::OpenedBrowser { url, user_code } => (
668            acct::LoginStage::LOGIN_STAGE_OPENED_BROWSER,
669            Some(url.clone()),
670            user_code.clone(),
671        ),
672        P::GotCode => (acct::LoginStage::LOGIN_STAGE_GOT_CODE, None, None),
673        P::Exchanging => (acct::LoginStage::LOGIN_STAGE_EXCHANGING, None, None),
674    };
675    acct::LoginEvent {
676        stage: stage.into(),
677        message,
678        error: None,
679        user_code,
680        ..Default::default()
681    }
682}
683
684// ═══════════════════════ AmpService ══════════════════════════════
685
686struct AmpServiceImpl(Arc<AppState>);
687
688fn to_pb_summary(s: &internal_threads::AmpThreadSummary) -> amp_pb::ThreadSummary {
689    amp_pb::ThreadSummary {
690        id: s.id.clone(),
691        created: s.created,
692        title: s.title.clone(),
693        message_count: clamp_to_u32(s.message_count),
694        agent_mode: s.agent_mode.clone(),
695        last_model: s.last_model.clone(),
696        total_input_tokens: s.total_input_tokens,
697        total_output_tokens: s.total_output_tokens,
698        file_size_bytes: s.file_size_bytes,
699        ..Default::default()
700    }
701}
702
703fn to_pb_content_block(b: internal_threads::AmpContentBlock) -> amp_pb::ContentBlock {
704    use amp_pb::content_block::Block;
705    let block = Some(match b {
706        internal_threads::AmpContentBlock::Text { text } => Block::Text(text),
707        internal_threads::AmpContentBlock::Thinking { thinking } => Block::Thinking(thinking),
708        internal_threads::AmpContentBlock::ToolUse { id, name, input } => {
709            Block::ToolUse(Box::new(amp_pb::ToolUse {
710                id,
711                name,
712                input: MessageField::some(json_to_pb_struct(input)),
713                ..Default::default()
714            }))
715        }
716        internal_threads::AmpContentBlock::ToolResult { tool_use_id, run } => {
717            Block::ToolResult(Box::new(amp_pb::ToolResult {
718                tool_use_id,
719                run: MessageField::some(amp_pb::ToolRun {
720                    status: run.status,
721                    result: run.result.map(json_to_pb_value).into(),
722                    error: run.error.map(json_to_pb_value).into(),
723                    ..Default::default()
724                }),
725                ..Default::default()
726            }))
727        }
728        internal_threads::AmpContentBlock::Unknown { original_type } => {
729            Block::UnknownType(original_type.unwrap_or_default())
730        }
731    });
732    amp_pb::ContentBlock {
733        block,
734        ..Default::default()
735    }
736}
737
738fn to_pb_message(m: internal_threads::AmpMessage) -> amp_pb::Message {
739    amp_pb::Message {
740        role: m.role,
741        message_id: m.message_id,
742        content: m.content.into_iter().map(to_pb_content_block).collect(),
743        usage: m
744            .usage
745            .map(|u| amp_pb::Usage {
746                model: u.model,
747                input_tokens: u.input_tokens,
748                output_tokens: u.output_tokens,
749                cache_creation_input_tokens: u.cache_creation_input_tokens,
750                cache_read_input_tokens: u.cache_read_input_tokens,
751                total_input_tokens: u.total_input_tokens,
752                ..Default::default()
753            })
754            .into(),
755        state: m
756            .state
757            .map(|s| amp_pb::MessageState {
758                state_type: s.state_type,
759                stop_reason: s.stop_reason,
760                ..Default::default()
761            })
762            .into(),
763        ..Default::default()
764    }
765}
766
767fn to_pb_detail(d: internal_threads::AmpThreadDetail) -> amp_pb::ThreadDetail {
768    amp_pb::ThreadDetail {
769        id: d.id,
770        v: d.v,
771        created: d.created,
772        title: d.title,
773        agent_mode: d.agent_mode,
774        messages: d.messages.into_iter().map(to_pb_message).collect(),
775        relationships: d
776            .relationships
777            .into_iter()
778            .map(|r| amp_pb::Relationship {
779                thread_id: r.thread_id,
780                rel_type: r.rel_type,
781                role: r.role,
782                ..Default::default()
783            })
784            .collect(),
785        env: d.env.map(json_to_pb_struct).into(),
786        ..Default::default()
787    }
788}
789
790impl amp_pb::AmpService for AmpServiceImpl {
791    async fn list_threads(
792        &self,
793        ctx: Context,
794        request: OwnedView<amp_pb::ListThreadsRequestView<'static>>,
795    ) -> Result<(amp_pb::ListThreadsResponse, Context), ConnectError> {
796        let req = request.to_owned_message();
797        let all = self.0.amp_threads.list();
798        let want_messages = req.has_messages.unwrap_or(true);
799        let filtered: Vec<_> = all
800            .iter()
801            .filter(|s| !want_messages || s.message_count > 0)
802            .collect();
803        let total = filtered.len();
804        let limit = usize::try_from(req.limit.unwrap_or(50))
805            .unwrap_or(50)
806            .min(200);
807        let offset = usize::try_from(req.offset.unwrap_or(0))
808            .unwrap_or(0)
809            .min(total);
810        let threads = filtered
811            .into_iter()
812            .skip(offset)
813            .take(limit)
814            .map(to_pb_summary)
815            .collect();
816        Ok((
817            amp_pb::ListThreadsResponse {
818                threads,
819                total: clamp_to_u32(total),
820                ..Default::default()
821            },
822            ctx,
823        ))
824    }
825
826    async fn get_thread(
827        &self,
828        ctx: Context,
829        request: OwnedView<amp_pb::GetThreadRequestView<'static>>,
830    ) -> Result<(amp_pb::GetThreadResponse, Context), ConnectError> {
831        let req = request.to_owned_message();
832        if !internal_threads::is_valid_thread_id(&req.id) {
833            return Err(ConnectError::invalid_argument("invalid thread ID format"));
834        }
835        let path = internal_threads::threads_dir().join(format!("{}.json", req.id));
836        #[allow(clippy::result_large_err)]
837        let detail = tokio::task::spawn_blocking(move || {
838            if !path.exists() {
839                return Err(ConnectError::not_found("thread not found"));
840            }
841            internal_threads::parse_detail(&path).map_err(|e| {
842                tracing::error!(error = %e, "failed to parse amp thread");
843                ConnectError::internal("failed to parse thread")
844            })
845        })
846        .await
847        .map_err(|e| ConnectError::internal(format!("spawn_blocking failed: {e}")))??;
848        Ok((
849            amp_pb::GetThreadResponse {
850                thread: MessageField::some(to_pb_detail(detail)),
851                ..Default::default()
852            },
853            ctx,
854        ))
855    }
856
857    async fn inject_url(
858        &self,
859        ctx: Context,
860        request: OwnedView<amp_pb::InjectUrlRequestView<'static>>,
861    ) -> Result<(amp_pb::InjectUrlResponse, Context), ConnectError> {
862        let req = request.to_owned_message();
863        let snapshot = self.0.config.load();
864        let resolved_url =
865            snapshot
866                .amp
867                .resolve_url(req.url.as_deref(), &snapshot.host, snapshot.port);
868        let settings_path = byokey_config::AmpConfig::default_settings_path()
869            .ok_or_else(|| ConnectError::internal("cannot determine HOME directory"))?;
870
871        let amp_cfg = snapshot.amp.clone();
872        let settings_path_for_spawn = settings_path.clone();
873        let resolved_url_for_spawn = resolved_url.clone();
874        #[allow(clippy::result_large_err)]
875        let extras = tokio::task::spawn_blocking(move || {
876            amp_cfg
877                .inject(&resolved_url_for_spawn, &settings_path_for_spawn)
878                .map_err(|e| ConnectError::internal(format!("inject failed: {e}")))
879        })
880        .await
881        .map_err(|e| ConnectError::internal(format!("spawn_blocking failed: {e}")))??;
882
883        Ok((
884            amp_pb::InjectUrlResponse {
885                resolved_url,
886                settings_path: settings_path.display().to_string(),
887                extras_merged: clamp_to_u32(extras),
888                ..Default::default()
889            },
890            ctx,
891        ))
892    }
893}
894
#[cfg(test)]
mod tests {
    use super::*;
    use byokey_auth::flow::LoginProgress;

    /// Shorthand for an `OpenedBrowser` progress event with the given URL and
    /// optional device-flow user code.
    fn opened_browser(url: &str, user_code: Option<&str>) -> LoginProgress {
        LoginProgress::OpenedBrowser {
            url: url.into(),
            user_code: user_code.map(Into::into),
        }
    }

    #[test]
    fn progress_to_pb_started() {
        let event = progress_to_pb(&LoginProgress::Started);
        assert_eq!(event.stage, acct::LoginStage::LOGIN_STAGE_STARTED);
        // A bare "started" event carries neither a message nor a user code.
        assert_eq!(event.message.as_deref(), None);
        assert_eq!(event.user_code.as_deref(), None);
    }

    #[test]
    fn progress_to_pb_opened_browser_auth_code() {
        let event = progress_to_pb(&opened_browser("https://example.com/auth", None));
        assert_eq!(event.stage, acct::LoginStage::LOGIN_STAGE_OPENED_BROWSER);
        // The browser URL is surfaced through `message`.
        assert_eq!(event.message.as_deref(), Some("https://example.com/auth"));
        assert_eq!(event.user_code.as_deref(), None);
    }

    #[test]
    fn progress_to_pb_opened_browser_device_code() {
        let event = progress_to_pb(&opened_browser(
            "https://github.com/login/device",
            Some("ABCD-1234"),
        ));
        assert_eq!(event.stage, acct::LoginStage::LOGIN_STAGE_OPENED_BROWSER);
        assert_eq!(
            event.message.as_deref(),
            Some("https://github.com/login/device")
        );
        // Device-code flows additionally forward the code the user must type.
        assert_eq!(event.user_code.as_deref(), Some("ABCD-1234"));
    }

    #[test]
    fn progress_to_pb_got_code() {
        let stage = progress_to_pb(&LoginProgress::GotCode).stage;
        assert_eq!(stage, acct::LoginStage::LOGIN_STAGE_GOT_CODE);
    }

    #[test]
    fn progress_to_pb_exchanging() {
        let stage = progress_to_pb(&LoginProgress::Exchanging).stage;
        assert_eq!(stage, acct::LoginStage::LOGIN_STAGE_EXCHANGING);
    }
}