1use std::sync::Arc;
9
10use buffa::MessageField;
11use buffa::view::OwnedView;
12use buffa_types::google::protobuf::value::Kind;
13use buffa_types::google::protobuf::{ListValue, NullValue, Struct, Value};
14use connectrpc::{ConnectError, Context, Router as ConnectRouter};
15use serde_json::Value as JsonValue;
16
17use byokey_proto::byokey::accounts as acct;
18use byokey_proto::byokey::amp as amp_pb;
19use byokey_proto::byokey::status as stat;
20
21use crate::AppState;
22use crate::handler::amp::threads as internal_threads;
23
24#[must_use]
28pub fn build_router(state: Arc<AppState>) -> ConnectRouter {
29 use acct::AccountsServiceExt as _;
30 use amp_pb::AmpServiceExt as _;
31 use stat::StatusServiceExt as _;
32
33 let router = ConnectRouter::new();
34 let router = Arc::new(StatusServiceImpl(state.clone())).register(router);
35 let router = Arc::new(AccountsServiceImpl(state.clone())).register(router);
36 Arc::new(AmpServiceImpl(state)).register(router)
37}
38
39fn byok_to_connect_error(e: &byokey_types::ByokError) -> ConnectError {
42 use byokey_types::ByokError;
43 let msg = e.to_string();
44 match e {
45 ByokError::Auth(_) | ByokError::TokenNotFound(_) | ByokError::TokenExpired(_) => {
46 ConnectError::unauthenticated(msg)
47 }
48 ByokError::UnsupportedModel(_) => ConnectError::not_found(msg),
49 ByokError::UnsupportedProvider(_) | ByokError::Translation(_) => {
50 ConnectError::invalid_argument(msg)
51 }
52 ByokError::ProviderUnavailable(_) => ConnectError::unavailable(msg),
53 _ => ConnectError::internal(msg),
54 }
55}
56
57fn json_to_pb_value(v: JsonValue) -> Value {
58 let kind = match v {
59 JsonValue::Null => Kind::NullValue(NullValue::NULL_VALUE.into()),
60 JsonValue::Bool(b) => Kind::BoolValue(b),
61 JsonValue::Number(n) => Kind::NumberValue(n.as_f64().unwrap_or(0.0)),
62 JsonValue::String(s) => Kind::StringValue(s),
63 JsonValue::Array(arr) => {
64 let values = arr.into_iter().map(json_to_pb_value).collect();
65 Kind::ListValue(Box::new(ListValue {
66 values,
67 ..Default::default()
68 }))
69 }
70 JsonValue::Object(map) => {
71 let fields = map
72 .into_iter()
73 .map(|(k, v)| (k, json_to_pb_value(v)))
74 .collect();
75 Kind::StructValue(Box::new(Struct {
76 fields,
77 ..Default::default()
78 }))
79 }
80 };
81 Value {
82 kind: Some(kind),
83 ..Default::default()
84 }
85}
86
87fn json_to_pb_struct(v: JsonValue) -> Struct {
88 if let JsonValue::Object(map) = v {
89 let fields = map
90 .into_iter()
91 .map(|(k, v)| (k, json_to_pb_value(v)))
92 .collect();
93 Struct {
94 fields,
95 ..Default::default()
96 }
97 } else {
98 Struct::default()
99 }
100}
101
/// Narrows a `usize` to `u32`, saturating at `u32::MAX` instead of truncating.
fn clamp_to_u32(n: usize) -> u32 {
    let narrowed = u32::try_from(n);
    // The conversion only fails when `n` is too large for `u32`.
    narrowed.unwrap_or(u32::MAX)
}
105
/// Thirty days, in seconds. `get_usage_by_account` looks back this far from `to` when
/// the request does not supply `from`.
const THIRTY_DAYS_SECS: i64 = 30 * 24 * 3600;
/// 365 days, in seconds. `get_usage_by_account` rejects requests whose `from`..`to`
/// span is longer than this.
const MAX_USAGE_RANGE_SECS: i64 = 365 * 24 * 3600;
110
111fn policy_strategy_to_proto(kind: byokey_config::PolicyStrategyKind) -> stat::RoutingStrategy {
112 use byokey_config::PolicyStrategyKind as K;
113 match kind {
114 K::RoundRobin => stat::RoutingStrategy::ROUTING_STRATEGY_ROUND_ROBIN,
115 K::WeightedRoundRobin => stat::RoutingStrategy::ROUTING_STRATEGY_WEIGHTED_ROUND_ROBIN,
116 K::Random => stat::RoutingStrategy::ROUTING_STRATEGY_RANDOM,
117 K::WeightedRandom => stat::RoutingStrategy::ROUTING_STRATEGY_WEIGHTED_RANDOM,
118 K::Priority => stat::RoutingStrategy::ROUTING_STRATEGY_PRIORITY,
119 }
120}
121
/// Returns the current wall-clock time as whole seconds since the Unix epoch.
///
/// A system clock set before the epoch yields `0`. A second count that does not fit in
/// `i64` saturates at `i64::MAX`, so the result is never negative.
fn now_seconds() -> i64 {
    let since_epoch = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .unwrap_or_default();
    // Checked conversion instead of an `as` cast: it cannot wrap to a negative value
    // and needs no lint suppression.
    i64::try_from(since_epoch.as_secs()).unwrap_or(i64::MAX)
}
129
130struct StatusServiceImpl(Arc<AppState>);
133
134impl stat::StatusService for StatusServiceImpl {
135 async fn get_status(
136 &self,
137 ctx: Context,
138 _: OwnedView<stat::GetStatusRequestView<'static>>,
139 ) -> Result<(stat::GetStatusResponse, Context), ConnectError> {
140 let snapshot = self.0.config.load();
141 let server = stat::ServerInfo {
142 host: snapshot.host.clone(),
143 port: u32::from(snapshot.port),
144 ..Default::default()
145 };
146
147 let mut providers = Vec::new();
148 for pid in byokey_types::ProviderId::all() {
149 let cfg = snapshot.providers.get(pid);
150 let has_key = cfg.is_some_and(|c| c.api_key.is_some() || !c.api_keys.is_empty());
151 let auth = if has_key || self.0.auth.is_authenticated(pid).await {
152 stat::AuthStatus::AUTH_STATUS_VALID
153 } else {
154 let accts = self.0.auth.list_accounts(pid).await.unwrap_or_default();
155 if accts.is_empty() {
156 stat::AuthStatus::AUTH_STATUS_NOT_CONFIGURED
157 } else {
158 stat::AuthStatus::AUTH_STATUS_EXPIRED
159 }
160 };
161 providers.push(stat::ProviderStatus {
162 id: pid.to_string(),
163 display_name: pid.display_name().to_string(),
164 enabled: cfg.is_none_or(|c| c.enabled),
165 auth_status: auth.into(),
166 models_count: clamp_to_u32(byokey_provider::models_for_provider(pid).len()),
167 ..Default::default()
168 });
169 }
170 Ok((
171 stat::GetStatusResponse {
172 server: server.into(),
173 providers,
174 ..Default::default()
175 },
176 ctx,
177 ))
178 }
179
180 async fn get_usage(
181 &self,
182 ctx: Context,
183 _: OwnedView<stat::GetUsageRequestView<'static>>,
184 ) -> Result<(stat::GetUsageResponse, Context), ConnectError> {
185 let s = self.0.usage.snapshot();
186 let models = s
187 .models
188 .into_iter()
189 .map(|(k, m)| {
190 (
191 k,
192 stat::ModelStats {
193 requests: m.requests,
194 success: m.success,
195 failure: m.failure,
196 input_tokens: m.input_tokens,
197 output_tokens: m.output_tokens,
198 ..Default::default()
199 },
200 )
201 })
202 .collect();
203 Ok((
204 stat::GetUsageResponse {
205 total_requests: s.total_requests,
206 success_requests: s.success_requests,
207 failure_requests: s.failure_requests,
208 input_tokens: s.input_tokens,
209 output_tokens: s.output_tokens,
210 models,
211 ..Default::default()
212 },
213 ctx,
214 ))
215 }
216
217 async fn get_usage_history(
218 &self,
219 ctx: Context,
220 request: OwnedView<stat::GetUsageHistoryRequestView<'static>>,
221 ) -> Result<(stat::GetUsageHistoryResponse, Context), ConnectError> {
222 let req = request.to_owned_message();
223 let Some(store) = self.0.usage.store() else {
224 let to = now_seconds();
225 return Ok((
226 stat::GetUsageHistoryResponse {
227 from: to - 86400,
228 to,
229 bucket_seconds: 3600,
230 error: Some("no persistent usage store configured".into()),
231 ..Default::default()
232 },
233 ctx,
234 ));
235 };
236 let to = req.to.unwrap_or_else(now_seconds);
237 let from = req.from.unwrap_or(to - 86400);
238 let range = to - from;
239 let bs = if range <= 86400 {
240 3600
241 } else if range <= 86400 * 7 {
242 21600
243 } else {
244 86400
245 };
246 let buckets = store
247 .query(from, to, req.model.as_deref(), bs)
248 .await
249 .map_err(|e| ConnectError::internal(e.to_string()))?
250 .into_iter()
251 .map(|b| stat::UsageBucket {
252 period_start: b.period_start,
253 model: b.model,
254 request_count: b.request_count,
255 input_tokens: b.input_tokens,
256 output_tokens: b.output_tokens,
257 ..Default::default()
258 })
259 .collect();
260 Ok((
261 stat::GetUsageHistoryResponse {
262 from,
263 to,
264 bucket_seconds: bs,
265 buckets,
266 error: None,
267 ..Default::default()
268 },
269 ctx,
270 ))
271 }
272
273 async fn get_usage_by_account(
274 &self,
275 ctx: Context,
276 request: OwnedView<stat::GetUsageByAccountRequestView<'static>>,
277 ) -> Result<(stat::GetUsageByAccountResponse, Context), ConnectError> {
278 let req = request.to_owned_message();
279 let Some(store) = self.0.usage.store() else {
280 return Ok((
281 stat::GetUsageByAccountResponse {
282 rows: Vec::new(),
283 error: Some("no persistent usage store configured".into()),
284 ..Default::default()
285 },
286 ctx,
287 ));
288 };
289 let to = req.to.unwrap_or_else(now_seconds);
290 let from = req
291 .from
292 .unwrap_or_else(|| to.saturating_sub(THIRTY_DAYS_SECS));
293 if from > to {
294 return Err(ConnectError::invalid_argument(
295 "from must be less than or equal to to",
296 ));
297 }
298 if to.saturating_sub(from) > MAX_USAGE_RANGE_SECS {
301 return Err(ConnectError::invalid_argument(format!(
302 "requested range exceeds maximum of {} days",
303 MAX_USAGE_RANGE_SECS / 86_400
304 )));
305 }
306 let totals = store
307 .totals_by_account(Some(from), Some(to))
308 .await
309 .map_err(|e| ConnectError::internal(e.to_string()))?;
310 let rows = totals
311 .into_iter()
312 .map(|t| stat::AccountUsageRow {
313 provider: t.provider,
314 account_id: t.account_id,
315 model: t.model,
316 request_count: t.request_count,
317 success_count: t.success_count,
318 input_tokens: t.input_tokens,
319 output_tokens: t.output_tokens,
320 ..Default::default()
321 })
322 .collect();
323 Ok((
324 stat::GetUsageByAccountResponse {
325 rows,
326 error: None,
327 ..Default::default()
328 },
329 ctx,
330 ))
331 }
332
333 async fn list_routing_policies(
334 &self,
335 ctx: Context,
336 _: OwnedView<stat::ListRoutingPoliciesRequestView<'static>>,
337 ) -> Result<(stat::ListRoutingPoliciesResponse, Context), ConnectError> {
338 let snapshot = self.0.config.load();
339 let policies = snapshot
340 .routing_policies
341 .iter()
342 .map(|entry| stat::RoutingPolicy {
343 provider: entry.provider.to_string(),
344 family: entry.family.clone().unwrap_or_default(),
345 strategy: policy_strategy_to_proto(entry.strategy).into(),
346 accounts: entry.accounts.clone(),
347 weights: entry.weights.clone(),
348 ..Default::default()
349 })
350 .collect();
351 Ok((
352 stat::ListRoutingPoliciesResponse {
353 policies,
354 ..Default::default()
355 },
356 ctx,
357 ))
358 }
359
360 async fn set_routing_policy(
361 &self,
362 _ctx: Context,
363 _: OwnedView<stat::SetRoutingPolicyRequestView<'static>>,
364 ) -> Result<(stat::SetRoutingPolicyResponse, Context), ConnectError> {
365 Err(ConnectError::unimplemented(
369 "SetRoutingPolicy is not yet implemented; edit settings.json directly",
370 ))
371 }
372
373 async fn get_rate_limits(
374 &self,
375 ctx: Context,
376 _: OwnedView<stat::GetRateLimitsRequestView<'static>>,
377 ) -> Result<(stat::GetRateLimitsResponse, Context), ConnectError> {
378 let all = self.0.ratelimits.all();
379 let mut by_prov: std::collections::HashMap<
380 byokey_types::ProviderId,
381 Vec<stat::AccountRateLimit>,
382 > = std::collections::HashMap::new();
383 for ((prov, aid), snap) in all {
384 by_prov
385 .entry(prov)
386 .or_default()
387 .push(stat::AccountRateLimit {
388 account_id: aid,
389 snapshot: MessageField::some(stat::RateLimitSnapshot {
390 headers: snap.headers,
391 captured_at: snap.captured_at,
392 ..Default::default()
393 }),
394 ..Default::default()
395 });
396 }
397 let providers = byokey_types::ProviderId::all()
398 .iter()
399 .filter_map(|pid| {
400 let accts = by_prov.remove(pid)?;
401 Some(stat::ProviderRateLimits {
402 id: pid.to_string(),
403 display_name: pid.display_name().to_string(),
404 accounts: accts,
405 ..Default::default()
406 })
407 })
408 .collect();
409 Ok((
410 stat::GetRateLimitsResponse {
411 providers,
412 ..Default::default()
413 },
414 ctx,
415 ))
416 }
417}
418
419struct AccountsServiceImpl(Arc<AppState>);
422
423impl acct::AccountsService for AccountsServiceImpl {
424 async fn list_accounts(
425 &self,
426 ctx: Context,
427 _: OwnedView<acct::ListAccountsRequestView<'static>>,
428 ) -> Result<(acct::ListAccountsResponse, Context), ConnectError> {
429 let mut providers = Vec::new();
430 for pid in byokey_types::ProviderId::all() {
431 let infos = self.0.auth.list_accounts(pid).await.unwrap_or_default();
432 let tokens = self.0.auth.get_all_tokens(pid).await.unwrap_or_default();
433 let accounts = infos
434 .iter()
435 .map(|info| {
436 let (ts, exp) = match tokens.iter().find(|(id, _)| id == &info.account_id) {
437 Some((_, tok)) => {
438 let s = match tok.state() {
439 byokey_types::TokenState::Valid => {
440 acct::TokenState::TOKEN_STATE_VALID
441 }
442 byokey_types::TokenState::Expired => {
443 acct::TokenState::TOKEN_STATE_EXPIRED
444 }
445 byokey_types::TokenState::Invalid => {
446 acct::TokenState::TOKEN_STATE_INVALID
447 }
448 };
449 (s, tok.expires_at)
450 }
451 None => (acct::TokenState::TOKEN_STATE_INVALID, None),
452 };
453 acct::AccountDetail {
454 account_id: info.account_id.clone(),
455 label: info.label.clone(),
456 is_active: info.is_active,
457 token_state: ts.into(),
458 expires_at: exp,
459 ..Default::default()
460 }
461 })
462 .collect();
463 providers.push(acct::ProviderAccounts {
464 id: pid.to_string(),
465 display_name: pid.display_name().to_string(),
466 accounts,
467 ..Default::default()
468 });
469 }
470 Ok((
471 acct::ListAccountsResponse {
472 providers,
473 ..Default::default()
474 },
475 ctx,
476 ))
477 }
478
479 async fn remove_account(
480 &self,
481 ctx: Context,
482 request: OwnedView<acct::RemoveAccountRequestView<'static>>,
483 ) -> Result<(acct::RemoveAccountResponse, Context), ConnectError> {
484 let req = request.to_owned_message();
485 let pid: byokey_types::ProviderId = req
486 .provider
487 .parse()
488 .map_err(|e: byokey_types::ByokError| byok_to_connect_error(&e))?;
489 self.0
490 .auth
491 .remove_token_for(&pid, &req.account_id)
492 .await
493 .map_err(|e| byok_to_connect_error(&e))?;
494 Ok((acct::RemoveAccountResponse::default(), ctx))
495 }
496
497 async fn activate_account(
498 &self,
499 ctx: Context,
500 request: OwnedView<acct::ActivateAccountRequestView<'static>>,
501 ) -> Result<(acct::ActivateAccountResponse, Context), ConnectError> {
502 let req = request.to_owned_message();
503 let pid: byokey_types::ProviderId = req
504 .provider
505 .parse()
506 .map_err(|e: byokey_types::ByokError| byok_to_connect_error(&e))?;
507 self.0
508 .auth
509 .set_active_account(&pid, &req.account_id)
510 .await
511 .map_err(|e| byok_to_connect_error(&e))?;
512 Ok((acct::ActivateAccountResponse::default(), ctx))
513 }
514
515 async fn add_api_key(
516 &self,
517 ctx: Context,
518 request: OwnedView<acct::AddApiKeyRequestView<'static>>,
519 ) -> Result<(acct::AddApiKeyResponse, Context), ConnectError> {
520 let req = request.to_owned_message();
521 let pid: byokey_types::ProviderId = req
522 .provider
523 .parse()
524 .map_err(|e: byokey_types::ByokError| byok_to_connect_error(&e))?;
525 if req.api_key.trim().is_empty() {
526 return Err(ConnectError::invalid_argument("api_key cannot be empty"));
527 }
528 if req.api_key.len() > byokey_types::MAX_API_KEY_BYTES {
529 return Err(ConnectError::invalid_argument(format!(
530 "api_key exceeds maximum length of {} bytes",
531 byokey_types::MAX_API_KEY_BYTES
532 )));
533 }
534 let account_id = req
535 .account_id
536 .unwrap_or_else(|| byokey_types::DEFAULT_ACCOUNT.to_string());
537 let token = byokey_types::OAuthToken {
538 access_token: req.api_key.trim().to_string(),
539 refresh_token: None,
540 expires_at: None,
541 token_type: Some("api-key".to_string()),
542 };
543 self.0
544 .auth
545 .save_token_for(&pid, &account_id, req.label.as_deref(), token)
546 .await
547 .map_err(|e| byok_to_connect_error(&e))?;
548 Ok((
549 acct::AddApiKeyResponse {
550 account_id,
551 ..Default::default()
552 },
553 ctx,
554 ))
555 }
556
557 async fn import_claude_code(
558 &self,
559 ctx: Context,
560 request: OwnedView<acct::ImportClaudeCodeRequestView<'static>>,
561 ) -> Result<(acct::ImportClaudeCodeResponse, Context), ConnectError> {
562 let req = request.to_owned_message();
563 let token = byokey_auth::provider::claude_code::load_token()
564 .await
565 .map_err(|e| byok_to_connect_error(&e))?
566 .ok_or_else(|| {
567 ConnectError::failed_precondition(
568 "no Claude Code credentials found — is Claude Code logged in on this machine?",
569 )
570 })?;
571 let pid = byokey_types::ProviderId::Claude;
572 let account_id = req
573 .account_id
574 .unwrap_or_else(|| byokey_types::CLAUDE_CODE_ACCOUNT.to_string());
575 let label = req.label.unwrap_or_else(|| "Claude Code".to_string());
576 self.0
577 .auth
578 .save_token_for(&pid, &account_id, Some(label.as_str()), token)
579 .await
580 .map_err(|e| byok_to_connect_error(&e))?;
581 Ok((
582 acct::ImportClaudeCodeResponse {
583 account_id,
584 ..Default::default()
585 },
586 ctx,
587 ))
588 }
589
590 async fn login(
591 &self,
592 ctx: Context,
593 request: OwnedView<acct::LoginRequestView<'static>>,
594 ) -> Result<
595 (
596 std::pin::Pin<
597 Box<dyn futures_util::Stream<Item = Result<acct::LoginEvent, ConnectError>> + Send>,
598 >,
599 Context,
600 ),
601 ConnectError,
602 > {
603 use futures_util::StreamExt as _;
604 use tokio_stream::wrappers::ReceiverStream;
605
606 let req = request.to_owned_message();
607 let pid: byokey_types::ProviderId = req
608 .provider
609 .parse()
610 .map_err(|e: byokey_types::ByokError| byok_to_connect_error(&e))?;
611 let account = req.account_id;
612
613 let (progress_tx, progress_rx) =
614 tokio::sync::mpsc::channel::<byokey_auth::flow::LoginProgress>(8);
615 let (event_tx, event_rx) =
616 tokio::sync::mpsc::channel::<Result<acct::LoginEvent, ConnectError>>(16);
617
618 let auth = self.0.auth.clone();
619 let event_tx_drive = event_tx.clone();
620 tokio::spawn(async move {
621 let mut progress_rx = progress_rx;
622 let account_ref = account.as_deref();
623 let login_fut =
624 byokey_auth::flow::login_with_events(&pid, &auth, account_ref, Some(progress_tx));
625 tokio::pin!(login_fut);
626
627 loop {
628 tokio::select! {
629 biased;
630 Some(p) = progress_rx.recv() => {
631 let ev = progress_to_pb(&p);
632 if event_tx_drive.send(Ok(ev)).await.is_err() { return; }
633 }
634 res = &mut login_fut => {
635 while let Ok(p) = progress_rx.try_recv() {
637 let ev = progress_to_pb(&p);
638 let _ = event_tx_drive.send(Ok(ev)).await;
639 }
640 let terminal = match res {
641 Ok(()) => acct::LoginEvent {
642 stage: acct::LoginStage::LOGIN_STAGE_DONE.into(),
643 ..Default::default()
644 },
645 Err(e) => acct::LoginEvent {
646 stage: acct::LoginStage::LOGIN_STAGE_FAILED.into(),
647 error: Some(e.to_string()),
648 ..Default::default()
649 },
650 };
651 let _ = event_tx_drive.send(Ok(terminal)).await;
652 return;
653 }
654 }
655 }
656 });
657
658 let stream = ReceiverStream::new(event_rx).boxed();
659 Ok((stream, ctx))
660 }
661}
662
663fn progress_to_pb(p: &byokey_auth::flow::LoginProgress) -> acct::LoginEvent {
664 use byokey_auth::flow::LoginProgress as P;
665 let (stage, message, user_code) = match p {
666 P::Started => (acct::LoginStage::LOGIN_STAGE_STARTED, None, None),
667 P::OpenedBrowser { url, user_code } => (
668 acct::LoginStage::LOGIN_STAGE_OPENED_BROWSER,
669 Some(url.clone()),
670 user_code.clone(),
671 ),
672 P::GotCode => (acct::LoginStage::LOGIN_STAGE_GOT_CODE, None, None),
673 P::Exchanging => (acct::LoginStage::LOGIN_STAGE_EXCHANGING, None, None),
674 };
675 acct::LoginEvent {
676 stage: stage.into(),
677 message,
678 error: None,
679 user_code,
680 ..Default::default()
681 }
682}
683
/// Connect handler for the Amp service (implements `amp_pb::AmpService`); wraps the
/// shared [`AppState`].
struct AmpServiceImpl(Arc<AppState>);
687
688fn to_pb_summary(s: &internal_threads::AmpThreadSummary) -> amp_pb::ThreadSummary {
689 amp_pb::ThreadSummary {
690 id: s.id.clone(),
691 created: s.created,
692 title: s.title.clone(),
693 message_count: clamp_to_u32(s.message_count),
694 agent_mode: s.agent_mode.clone(),
695 last_model: s.last_model.clone(),
696 total_input_tokens: s.total_input_tokens,
697 total_output_tokens: s.total_output_tokens,
698 file_size_bytes: s.file_size_bytes,
699 ..Default::default()
700 }
701}
702
703fn to_pb_content_block(b: internal_threads::AmpContentBlock) -> amp_pb::ContentBlock {
704 use amp_pb::content_block::Block;
705 let block = Some(match b {
706 internal_threads::AmpContentBlock::Text { text } => Block::Text(text),
707 internal_threads::AmpContentBlock::Thinking { thinking } => Block::Thinking(thinking),
708 internal_threads::AmpContentBlock::ToolUse { id, name, input } => {
709 Block::ToolUse(Box::new(amp_pb::ToolUse {
710 id,
711 name,
712 input: MessageField::some(json_to_pb_struct(input)),
713 ..Default::default()
714 }))
715 }
716 internal_threads::AmpContentBlock::ToolResult { tool_use_id, run } => {
717 Block::ToolResult(Box::new(amp_pb::ToolResult {
718 tool_use_id,
719 run: MessageField::some(amp_pb::ToolRun {
720 status: run.status,
721 result: run.result.map(json_to_pb_value).into(),
722 error: run.error.map(json_to_pb_value).into(),
723 ..Default::default()
724 }),
725 ..Default::default()
726 }))
727 }
728 internal_threads::AmpContentBlock::Unknown { original_type } => {
729 Block::UnknownType(original_type.unwrap_or_default())
730 }
731 });
732 amp_pb::ContentBlock {
733 block,
734 ..Default::default()
735 }
736}
737
738fn to_pb_message(m: internal_threads::AmpMessage) -> amp_pb::Message {
739 amp_pb::Message {
740 role: m.role,
741 message_id: m.message_id,
742 content: m.content.into_iter().map(to_pb_content_block).collect(),
743 usage: m
744 .usage
745 .map(|u| amp_pb::Usage {
746 model: u.model,
747 input_tokens: u.input_tokens,
748 output_tokens: u.output_tokens,
749 cache_creation_input_tokens: u.cache_creation_input_tokens,
750 cache_read_input_tokens: u.cache_read_input_tokens,
751 total_input_tokens: u.total_input_tokens,
752 ..Default::default()
753 })
754 .into(),
755 state: m
756 .state
757 .map(|s| amp_pb::MessageState {
758 state_type: s.state_type,
759 stop_reason: s.stop_reason,
760 ..Default::default()
761 })
762 .into(),
763 ..Default::default()
764 }
765}
766
767fn to_pb_detail(d: internal_threads::AmpThreadDetail) -> amp_pb::ThreadDetail {
768 amp_pb::ThreadDetail {
769 id: d.id,
770 v: d.v,
771 created: d.created,
772 title: d.title,
773 agent_mode: d.agent_mode,
774 messages: d.messages.into_iter().map(to_pb_message).collect(),
775 relationships: d
776 .relationships
777 .into_iter()
778 .map(|r| amp_pb::Relationship {
779 thread_id: r.thread_id,
780 rel_type: r.rel_type,
781 role: r.role,
782 ..Default::default()
783 })
784 .collect(),
785 env: d.env.map(json_to_pb_struct).into(),
786 ..Default::default()
787 }
788}
789
790impl amp_pb::AmpService for AmpServiceImpl {
791 async fn list_threads(
792 &self,
793 ctx: Context,
794 request: OwnedView<amp_pb::ListThreadsRequestView<'static>>,
795 ) -> Result<(amp_pb::ListThreadsResponse, Context), ConnectError> {
796 let req = request.to_owned_message();
797 let all = self.0.amp_threads.list();
798 let want_messages = req.has_messages.unwrap_or(true);
799 let filtered: Vec<_> = all
800 .iter()
801 .filter(|s| !want_messages || s.message_count > 0)
802 .collect();
803 let total = filtered.len();
804 let limit = usize::try_from(req.limit.unwrap_or(50))
805 .unwrap_or(50)
806 .min(200);
807 let offset = usize::try_from(req.offset.unwrap_or(0))
808 .unwrap_or(0)
809 .min(total);
810 let threads = filtered
811 .into_iter()
812 .skip(offset)
813 .take(limit)
814 .map(to_pb_summary)
815 .collect();
816 Ok((
817 amp_pb::ListThreadsResponse {
818 threads,
819 total: clamp_to_u32(total),
820 ..Default::default()
821 },
822 ctx,
823 ))
824 }
825
826 async fn get_thread(
827 &self,
828 ctx: Context,
829 request: OwnedView<amp_pb::GetThreadRequestView<'static>>,
830 ) -> Result<(amp_pb::GetThreadResponse, Context), ConnectError> {
831 let req = request.to_owned_message();
832 if !internal_threads::is_valid_thread_id(&req.id) {
833 return Err(ConnectError::invalid_argument("invalid thread ID format"));
834 }
835 let path = internal_threads::threads_dir().join(format!("{}.json", req.id));
836 #[allow(clippy::result_large_err)]
837 let detail = tokio::task::spawn_blocking(move || {
838 if !path.exists() {
839 return Err(ConnectError::not_found("thread not found"));
840 }
841 internal_threads::parse_detail(&path).map_err(|e| {
842 tracing::error!(error = %e, "failed to parse amp thread");
843 ConnectError::internal("failed to parse thread")
844 })
845 })
846 .await
847 .map_err(|e| ConnectError::internal(format!("spawn_blocking failed: {e}")))??;
848 Ok((
849 amp_pb::GetThreadResponse {
850 thread: MessageField::some(to_pb_detail(detail)),
851 ..Default::default()
852 },
853 ctx,
854 ))
855 }
856
857 async fn inject_url(
858 &self,
859 ctx: Context,
860 request: OwnedView<amp_pb::InjectUrlRequestView<'static>>,
861 ) -> Result<(amp_pb::InjectUrlResponse, Context), ConnectError> {
862 let req = request.to_owned_message();
863 let snapshot = self.0.config.load();
864 let resolved_url =
865 snapshot
866 .amp
867 .resolve_url(req.url.as_deref(), &snapshot.host, snapshot.port);
868 let settings_path = byokey_config::AmpConfig::default_settings_path()
869 .ok_or_else(|| ConnectError::internal("cannot determine HOME directory"))?;
870
871 let amp_cfg = snapshot.amp.clone();
872 let settings_path_for_spawn = settings_path.clone();
873 let resolved_url_for_spawn = resolved_url.clone();
874 #[allow(clippy::result_large_err)]
875 let extras = tokio::task::spawn_blocking(move || {
876 amp_cfg
877 .inject(&resolved_url_for_spawn, &settings_path_for_spawn)
878 .map_err(|e| ConnectError::internal(format!("inject failed: {e}")))
879 })
880 .await
881 .map_err(|e| ConnectError::internal(format!("spawn_blocking failed: {e}")))??;
882
883 Ok((
884 amp_pb::InjectUrlResponse {
885 resolved_url,
886 settings_path: settings_path.display().to_string(),
887 extras_merged: clamp_to_u32(extras),
888 ..Default::default()
889 },
890 ctx,
891 ))
892 }
893}
894
#[cfg(test)]
mod tests {
    use super::*;
    use byokey_auth::flow::LoginProgress;

    // `Started` carries no payload: only the stage is set.
    #[test]
    fn progress_to_pb_started() {
        let event = progress_to_pb(&LoginProgress::Started);

        assert_eq!(event.stage, acct::LoginStage::LOGIN_STAGE_STARTED);
        assert!(event.message.is_none());
        assert!(event.user_code.is_none());
    }

    // Authorization-code flow: the URL is surfaced as the message, no user code.
    #[test]
    fn progress_to_pb_opened_browser_auth_code() {
        let progress = LoginProgress::OpenedBrowser {
            url: "https://example.com/auth".into(),
            user_code: None,
        };
        let event = progress_to_pb(&progress);

        assert_eq!(event.stage, acct::LoginStage::LOGIN_STAGE_OPENED_BROWSER);
        assert_eq!(event.message.as_deref(), Some("https://example.com/auth"));
        assert!(event.user_code.is_none());
    }

    // Device-code flow: both the URL and the user code are passed through.
    #[test]
    fn progress_to_pb_opened_browser_device_code() {
        let progress = LoginProgress::OpenedBrowser {
            url: "https://github.com/login/device".into(),
            user_code: Some("ABCD-1234".into()),
        };
        let event = progress_to_pb(&progress);

        assert_eq!(event.stage, acct::LoginStage::LOGIN_STAGE_OPENED_BROWSER);
        assert_eq!(
            event.message.as_deref(),
            Some("https://github.com/login/device")
        );
        assert_eq!(event.user_code.as_deref(), Some("ABCD-1234"));
    }

    #[test]
    fn progress_to_pb_got_code() {
        let event = progress_to_pb(&LoginProgress::GotCode);

        assert_eq!(event.stage, acct::LoginStage::LOGIN_STAGE_GOT_CODE);
    }

    #[test]
    fn progress_to_pb_exchanging() {
        let event = progress_to_pb(&LoginProgress::Exchanging);

        assert_eq!(event.stage, acct::LoginStage::LOGIN_STAGE_EXCHANGING);
    }
}
944}