1use std::{
2 sync::Arc,
3 time::Duration,
4};
5
6use context69_contracts::{
7 ApiErrorResponse, AuthLoginRequest, AuthMeResponse, AuthTokenResponse, DocumentResponse,
8 GroupMemberResponse, GroupResponse, HealthResponse, ProjectMemberResponse, ProjectResponse,
9 SearchRequest, SearchResponse,
10};
11use reqwest::{
12 Method, RequestBuilder, Response, StatusCode, Url,
13 header::{AUTHORIZATION, USER_AGENT},
14};
15use tokio::sync::RwLock;
16
17use crate::Error;
18
/// Authentication state shared (through an `Arc<RwLock<_>>`) by every clone of a
/// `Context69Client`.
#[derive(Clone, Default)]
struct SessionState {
    /// Bearer token stored by the last successful login/refresh; `None` before
    /// login and after a successful logout.
    access_token: Option<String>,
}

// Hand-written instead of derived so the bearer token can never leak through
// `{:?}` output (logs, panic messages, assertion failures).
impl std::fmt::Debug for SessionState {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("SessionState")
            .field(
                "access_token",
                &self.access_token.as_ref().map(|_| "<redacted>"),
            )
            .finish()
    }
}
23
24#[derive(Clone)]
25pub struct Context69Client {
26 client: reqwest::Client,
27 base_url: Url,
28 session: Arc<RwLock<SessionState>>,
29}
30
31pub struct Context69ClientBuilder {
32 base_url: Option<Url>,
33 user_agent: Option<String>,
34 timeout: Option<Duration>,
35}
36
37impl Context69Client {
38 pub fn builder() -> Context69ClientBuilder {
39 Context69ClientBuilder {
40 base_url: None,
41 user_agent: None,
42 timeout: None,
43 }
44 }
45
46 pub async fn login(
47 &self,
48 login_name: impl Into<String>,
49 password: impl Into<String>,
50 ) -> Result<AuthTokenResponse, Error> {
51 let response = self
52 .client
53 .post(self.url("/v1/auth/login")?)
54 .json(&AuthLoginRequest {
55 login_name: login_name.into(),
56 password: password.into(),
57 })
58 .send()
59 .await?;
60 let payload: AuthTokenResponse = self.read_json_response(response).await?;
61 self.set_access_token(payload.access_token.clone()).await;
62 Ok(payload)
63 }
64
65 pub async fn refresh(&self) -> Result<AuthTokenResponse, Error> {
66 let response = self
67 .client
68 .post(self.url("/v1/auth/refresh")?)
69 .send()
70 .await?;
71 if !response.status().is_success() {
72 let status = response.status();
73 let body = response.text().await.unwrap_or_default();
74 let message = parse_api_error_message(&body).unwrap_or_else(|| body.clone());
75 return Err(Error::RefreshFailed {
76 status: Some(status),
77 message,
78 });
79 }
80
81 let payload = response.json::<AuthTokenResponse>().await?;
82 self.set_access_token(payload.access_token.clone()).await;
83 Ok(payload)
84 }
85
86 pub async fn logout(&self) -> Result<(), Error> {
87 let response = self
88 .client
89 .post(self.url("/v1/auth/logout")?)
90 .send()
91 .await?;
92 self.read_empty_response(response).await?;
93 self.clear_access_token().await;
94 Ok(())
95 }
96
97 pub async fn me(&self) -> Result<AuthMeResponse, Error> {
98 self.send_json(Method::GET, "/v1/auth/me", None::<&()>).await
99 }
100
101 pub async fn healthz(&self) -> Result<HealthResponse, Error> {
102 let response = self.client.get(self.url("/healthz")?).send().await?;
103 self.read_json_response(response).await
104 }
105
106 pub async fn search(&self, request: SearchRequest) -> Result<SearchResponse, Error> {
107 self.send_json(Method::POST, "/v1/search", Some(&request)).await
108 }
109
110 pub async fn get_document(&self, document_id: i64) -> Result<DocumentResponse, Error> {
111 let path = format!("/v1/documents/{document_id}");
112 self.send_json(Method::GET, &path, None::<&()>).await
113 }
114
115 pub async fn list_groups(&self) -> Result<Vec<GroupResponse>, Error> {
116 self.send_json(Method::GET, "/v1/groups", None::<&()>).await
117 }
118
119 pub async fn get_group(&self, group_key: &str) -> Result<GroupResponse, Error> {
120 let path = format!("/v1/groups/{group_key}");
121 self.send_json(Method::GET, &path, None::<&()>).await
122 }
123
124 pub async fn list_projects(&self, group_key: &str) -> Result<Vec<ProjectResponse>, Error> {
125 let path = format!("/v1/groups/{group_key}/projects");
126 self.send_json(Method::GET, &path, None::<&()>).await
127 }
128
129 pub async fn get_project(
130 &self,
131 group_key: &str,
132 project_key: &str,
133 ) -> Result<ProjectResponse, Error> {
134 let path = format!("/v1/groups/{group_key}/projects/{project_key}");
135 self.send_json(Method::GET, &path, None::<&()>).await
136 }
137
138 pub async fn list_group_members(
139 &self,
140 group_key: &str,
141 ) -> Result<Vec<GroupMemberResponse>, Error> {
142 let path = format!("/v1/groups/{group_key}/members");
143 self.send_json(Method::GET, &path, None::<&()>).await
144 }
145
146 pub async fn list_project_members(
147 &self,
148 group_key: &str,
149 project_key: &str,
150 ) -> Result<Vec<ProjectMemberResponse>, Error> {
151 let path = format!("/v1/groups/{group_key}/projects/{project_key}/members");
152 self.send_json(Method::GET, &path, None::<&()>).await
153 }
154
155 pub async fn list_sources(&self) -> Result<Vec<context69_contracts::SourceStatus>, Error> {
156 let response: Vec<context69_contracts::SourceStatus> =
157 self.send_json(Method::GET, "/v1/sources", None::<&()>).await?;
158 Ok(response)
159 }
160
161 async fn send_json<TReq, TRes>(
162 &self,
163 method: Method,
164 path: &str,
165 body: Option<TReq>,
166 ) -> Result<TRes, Error>
167 where
168 TReq: serde::Serialize,
169 TRes: serde::de::DeserializeOwned,
170 {
171 self.ensure_authenticated().await?;
172 let request_body = body
173 .map(|value| serde_json::to_value(value))
174 .transpose()
175 .map_err(Error::from)?;
176
177 let response = self
178 .send_with_refresh(method.clone(), path, request_body.clone())
179 .await?;
180 self.read_json_response(response).await
181 }
182
183 async fn send_with_refresh(
184 &self,
185 method: Method,
186 path: &str,
187 body: Option<serde_json::Value>,
188 ) -> Result<Response, Error> {
189 let response = self
190 .send_request(method.clone(), path, body.clone(), true)
191 .await?;
192
193 if response.status() != StatusCode::UNAUTHORIZED {
194 return Ok(response);
195 }
196
197 match self.refresh().await {
198 Ok(_) => self.send_request(method, path, body, true).await,
199 Err(Error::RefreshFailed { status, message }) => {
200 Err(Error::RefreshFailed { status, message })
201 }
202 Err(other) => Err(Error::RefreshFailed {
203 status: None,
204 message: other.to_string(),
205 }),
206 }
207 }
208
209 async fn send_request(
210 &self,
211 method: Method,
212 path: &str,
213 body: Option<serde_json::Value>,
214 include_auth: bool,
215 ) -> Result<Response, Error> {
216 let url = self.url(path)?;
217 let mut request = self.client.request(method, url);
218 if include_auth {
219 request = self.authorized(request).await?;
220 }
221 if let Some(body) = body {
222 request = request.json(&body);
223 }
224 Ok(request.send().await?)
225 }
226
227 async fn authorized(&self, request: RequestBuilder) -> Result<RequestBuilder, Error> {
228 let token = self
229 .session
230 .read()
231 .await
232 .access_token
233 .clone()
234 .ok_or(Error::AuthenticationRequired)?;
235 Ok(request.header(AUTHORIZATION, format!("Bearer {token}")))
236 }
237
238 async fn ensure_authenticated(&self) -> Result<(), Error> {
239 if self.session.read().await.access_token.is_some() {
240 Ok(())
241 } else {
242 Err(Error::AuthenticationRequired)
243 }
244 }
245
246 fn url(&self, path: &str) -> Result<Url, Error> {
247 self.base_url
248 .join(path.trim_start_matches('/'))
249 .map_err(|source| Error::UrlJoin {
250 path: path.to_string(),
251 source,
252 })
253 }
254
255 async fn set_access_token(&self, token: String) {
256 self.session.write().await.access_token = Some(token);
257 }
258
259 async fn clear_access_token(&self) {
260 self.session.write().await.access_token = None;
261 }
262
263 async fn read_empty_response(&self, response: Response) -> Result<(), Error> {
264 let status = response.status();
265 if status.is_success() {
266 return Ok(());
267 }
268 Err(self.build_http_error(response).await)
269 }
270
271 async fn read_json_response<T: serde::de::DeserializeOwned>(
272 &self,
273 response: Response,
274 ) -> Result<T, Error> {
275 let status = response.status();
276 if !status.is_success() {
277 return Err(self.build_http_error(response).await);
278 }
279 Ok(response.json::<T>().await?)
280 }
281
282 async fn build_http_error(&self, response: Response) -> Error {
283 let status = response.status();
284 let body = response.text().await.unwrap_or_default();
285 Error::HttpStatus {
286 status,
287 api_error: parse_api_error_message(&body),
288 body,
289 }
290 }
291}
292
293impl Context69ClientBuilder {
294 pub fn base_url(mut self, base_url: &str) -> Result<Self, Error> {
295 let mut url =
296 Url::parse(base_url).map_err(|_| Error::InvalidBaseUrl(base_url.to_string()))?;
297 if !url.path().ends_with('/') {
298 let next_path = format!("{}/", url.path());
299 url.set_path(&next_path);
300 }
301 self.base_url = Some(url);
302 Ok(self)
303 }
304
305 pub fn user_agent(mut self, user_agent: impl Into<String>) -> Self {
306 self.user_agent = Some(user_agent.into());
307 self
308 }
309
310 pub fn timeout(mut self, timeout: Duration) -> Result<Self, Error> {
311 if timeout.is_zero() {
312 return Err(Error::InvalidTimeout(timeout));
313 }
314 self.timeout = Some(timeout);
315 Ok(self)
316 }
317
318 pub fn build(self) -> Result<Context69Client, Error> {
319 let base_url = self
320 .base_url
321 .ok_or_else(|| Error::InvalidBaseUrl("missing base_url".to_string()))?;
322 let mut builder = reqwest::Client::builder().cookie_store(true);
323 if let Some(user_agent) = self.user_agent {
324 builder = builder.default_headers({
325 let mut headers = reqwest::header::HeaderMap::new();
326 headers.insert(
327 USER_AGENT,
328 user_agent
329 .parse()
330 .map_err(|_| Error::InvalidHeader(user_agent.clone()))?,
331 );
332 headers
333 });
334 }
335 if let Some(timeout) = self.timeout {
336 builder = builder.timeout(timeout);
337 }
338 let client = builder.build()?;
339 Ok(Context69Client {
340 client,
341 base_url,
342 session: Arc::new(RwLock::new(SessionState::default())),
343 })
344 }
345}
346
347fn parse_api_error_message(body: &str) -> Option<String> {
348 serde_json::from_str::<ApiErrorResponse>(body)
349 .ok()
350 .map(|value| value.error)
351}
352
#[cfg(test)]
mod tests {
    use std::sync::{
        Arc,
        atomic::{AtomicUsize, Ordering},
    };

    use super::*;
    use axum::{
        Json, Router,
        extract::State,
        http::{HeaderMap, StatusCode},
        response::IntoResponse,
        routing::{get, post},
    };
    use context69_contracts::{
        AuthUserResponse, GroupKind, HealthStatus, MembershipRole, SearchHit, Visibility,
    };
    use serde_json::json;
    use tokio::net::TcpListener;

    /// Counters bumped by the stub handlers so tests can assert call counts.
    #[derive(Clone, Default)]
    struct TestState {
        search_calls: Arc<AtomicUsize>,
        refresh_calls: Arc<AtomicUsize>,
    }

    /// Serves the stub API on an ephemeral localhost port. Returns its base URL
    /// and the shared counters.
    async fn spawn_test_server() -> (String, TestState) {
        let state = TestState::default();
        let app = Router::new()
            .route("/healthz", get(health_handler))
            .route("/v1/auth/login", post(login_handler))
            .route("/v1/auth/refresh", post(refresh_handler))
            .route("/v1/search", post(search_handler))
            .route("/v1/groups", get(groups_handler))
            .with_state(state.clone());

        let listener = TcpListener::bind("127.0.0.1:0").await.expect("bind listener");
        let addr = listener.local_addr().expect("local addr");
        tokio::spawn(async move {
            axum::serve(listener, app).await.expect("serve app");
        });
        (format!("http://{addr}"), state)
    }

    /// Builds a client for `base_url` with default settings.
    fn test_client(base_url: &str) -> Context69Client {
        Context69Client::builder()
            .base_url(base_url)
            .expect("base url")
            .build()
            .expect("client")
    }

    /// A search for `query` with a fixed limit and no filters.
    fn search_request(query: &str) -> SearchRequest {
        SearchRequest {
            query: query.to_string(),
            limit: 8,
            source_key: None,
            group_key: None,
            project_key: None,
            published_after: None,
            published_before: None,
        }
    }

    /// The `Authorization` header value, or `""` when absent or not visible ASCII.
    fn authorization(headers: &HeaderMap) -> &str {
        headers
            .get(AUTHORIZATION)
            .and_then(|value| value.to_str().ok())
            .unwrap_or_default()
    }

    /// An `ApiErrorResponse` JSON body carrying `message`, with the given status.
    fn api_error(status: StatusCode, message: &str) -> axum::response::Response {
        (
            status,
            Json(ApiErrorResponse {
                error: message.to_string(),
            }),
        )
            .into_response()
    }

    /// A token response that also sets the refresh cookie, as login/refresh do.
    fn token_with_refresh_cookie(access_token: &str) -> axum::response::Response {
        (
            [(
                "set-cookie",
                "context69_refresh=refresh-ok; HttpOnly; Path=/",
            )],
            Json(token_response(access_token)),
        )
            .into_response()
    }

    async fn health_handler() -> Json<HealthResponse> {
        Json(HealthResponse {
            status: HealthStatus::Ok,
            indexed_chunks: Some(7),
            db_ok: None,
            qdrant_ok: None,
        })
    }

    async fn login_handler() -> impl IntoResponse {
        token_with_refresh_cookie("token-initial")
    }

    async fn refresh_handler(State(state): State<TestState>) -> impl IntoResponse {
        state.refresh_calls.fetch_add(1, Ordering::SeqCst);
        token_with_refresh_cookie("token-refreshed")
    }

    /// Stub for `POST /v1/search`. The checks are applied in this order:
    /// 1. query `"bad request"` returns 400 with an API error body;
    /// 2. the very first call made with the initial token returns 401, simulating
    ///    an expired token;
    /// 3. any bearer other than the refreshed token returns 401;
    /// 4. otherwise a single canned hit is returned.
    async fn search_handler(
        State(state): State<TestState>,
        headers: HeaderMap,
        Json(request): Json<SearchRequest>,
    ) -> impl IntoResponse {
        let call = state.search_calls.fetch_add(1, Ordering::SeqCst);
        let bearer = authorization(&headers);

        if request.query == "bad request" {
            return api_error(StatusCode::BAD_REQUEST, "invalid query");
        }
        if call == 0 && bearer == "Bearer token-initial" {
            return api_error(StatusCode::UNAUTHORIZED, "expired");
        }
        if bearer != "Bearer token-refreshed" {
            return api_error(StatusCode::UNAUTHORIZED, "missing bearer token");
        }

        Json(SearchResponse {
            query: request.query,
            hits: vec![SearchHit {
                chunk_id: uuid::Uuid::nil(),
                document_id: 42,
                group_key: "team".to_string(),
                project_key: "docs".to_string(),
                visibility: Visibility::Private,
                source_key: "source".to_string(),
                external_id: "ext-1".to_string(),
                title: "Document".to_string(),
                summary: Some("Summary".to_string()),
                source_uri: "https://example.test/doc".to_string(),
                published_at: None,
                chunk_index: 0,
                chunk_text: "hello".to_string(),
                score: 0.9,
                vector_score: Some(0.9),
                keyword_score: None,
                rerank_score: None,
                match_reason: None,
                metadata_json: json!({}),
                library_file_id: None,
                library_section_label: None,
                library_path: None,
                is_library_file: false,
            }],
        })
        .into_response()
    }

    /// Stub for `GET /v1/groups`. It only accepts the refreshed token.
    async fn groups_handler(headers: HeaderMap) -> impl IntoResponse {
        if authorization(&headers) != "Bearer token-refreshed" {
            return api_error(StatusCode::UNAUTHORIZED, "missing bearer token");
        }

        Json(vec![GroupResponse {
            group_id: 1,
            group_key: "team".to_string(),
            parent_group_key: None,
            name: "Team".to_string(),
            visibility: Visibility::Private,
            kind: GroupKind::Shared,
            current_role: Some(MembershipRole::Owner),
            created_at: chrono::Utc::now(),
            updated_at: chrono::Utc::now(),
        }])
        .into_response()
    }

    fn token_response(access_token: &str) -> AuthTokenResponse {
        AuthTokenResponse {
            access_token: access_token.to_string(),
            token_type: "Bearer".to_string(),
            expires_in_secs: 3600,
            user: AuthUserResponse {
                user_id: 1,
                login_name: "admin".to_string(),
                display_name: "Administrator".to_string(),
                is_admin: true,
                disabled_at: None,
                personal_group_key: "admin".to_string(),
                personal_group_role: Some(MembershipRole::Owner),
            },
        }
    }

    #[test]
    fn builder_normalizes_base_url_with_trailing_slash() {
        let client = test_client("http://localhost:8096");
        assert_eq!(
            client.url("/healthz").expect("url").as_str(),
            "http://localhost:8096/healthz"
        );
    }

    #[test]
    fn builder_keeps_base_url_path_prefix() {
        for base in ["http://localhost:8096/api", "http://localhost:8096/api/"] {
            let client = test_client(base);
            assert_eq!(
                client.url("/healthz").expect("url").as_str(),
                "http://localhost:8096/api/healthz",
                "base url: {base}"
            );
        }
    }

    #[test]
    fn builder_rejects_zero_timeout() {
        let result = Context69Client::builder().timeout(Duration::ZERO);
        assert!(matches!(
            result,
            Err(Error::InvalidTimeout(timeout)) if timeout.is_zero()
        ));
    }

    #[test]
    fn builder_rejects_invalid_user_agent() {
        let result = Context69Client::builder()
            .base_url("http://localhost:8096")
            .expect("base url")
            .user_agent("bad\nagent")
            .build();
        assert!(matches!(result, Err(Error::InvalidHeader(_))));
    }

    #[test]
    fn parse_api_error_body() {
        let body = r#"{"error":"missing bearer token"}"#;
        assert_eq!(
            parse_api_error_message(body),
            Some("missing bearer token".to_string())
        );
    }

    #[tokio::test]
    async fn healthz_does_not_require_login() {
        let (base_url, _) = spawn_test_server().await;
        let health = test_client(&base_url).healthz().await.expect("health");
        assert_eq!(health.indexed_chunks, Some(7));
    }

    #[tokio::test]
    async fn protected_api_requires_login() {
        let (base_url, _) = spawn_test_server().await;
        let client = test_client(&base_url);

        let error = client
            .list_groups()
            .await
            .expect_err("should require authentication");
        assert!(matches!(error, Error::AuthenticationRequired));
    }

    #[tokio::test]
    async fn search_refreshes_once_and_retries() {
        let (base_url, state) = spawn_test_server().await;
        let client = test_client(&base_url);

        client.login("admin", "secret").await.expect("login");
        let response = client
            .search(search_request("policy"))
            .await
            .expect("search response");

        assert_eq!(response.hits.len(), 1);
        assert_eq!(state.refresh_calls.load(Ordering::SeqCst), 1);
        assert_eq!(state.search_calls.load(Ordering::SeqCst), 2);
    }

    #[tokio::test]
    async fn surfaces_api_error_message() {
        let (base_url, _) = spawn_test_server().await;
        let client = test_client(&base_url);

        client.login("admin", "secret").await.expect("login");
        let error = client
            .search(search_request("bad request"))
            .await
            .expect_err("should fail");

        match error {
            Error::HttpStatus {
                status, api_error, ..
            } => {
                assert_eq!(status, StatusCode::BAD_REQUEST);
                assert_eq!(api_error.as_deref(), Some("invalid query"));
            }
            other => panic!("unexpected error: {other}"),
        }
    }
}