Skip to main content

context69_sdk/
client.rs

1use std::{
2    sync::Arc,
3    time::Duration,
4};
5
6use context69_contracts::{
7    ApiErrorResponse, AuthLoginRequest, AuthMeResponse, AuthTokenResponse, DocumentResponse,
8    GroupMemberResponse, GroupResponse, HealthResponse, ProjectMemberResponse, ProjectResponse,
9    SearchRequest, SearchResponse,
10};
11use reqwest::{
12    Method, RequestBuilder, Response, StatusCode, Url,
13    header::{AUTHORIZATION, USER_AGENT},
14};
15use tokio::sync::RwLock;
16
17use crate::Error;
18
/// Authentication state shared (through an `Arc<RwLock<..>>`) by every clone of a client.
#[derive(Clone, Debug, Default)]
struct SessionState {
    /// Token sent as `Authorization: Bearer <token>`; `None` until a login/refresh succeeds
    /// and again after a successful logout.
    access_token: Option<String>,
}
23
24#[derive(Clone)]
25pub struct Context69Client {
26    client: reqwest::Client,
27    base_url: Url,
28    session: Arc<RwLock<SessionState>>,
29}
30
31pub struct Context69ClientBuilder {
32    base_url: Option<Url>,
33    user_agent: Option<String>,
34    timeout: Option<Duration>,
35}
36
37impl Context69Client {
38    pub fn builder() -> Context69ClientBuilder {
39        Context69ClientBuilder {
40            base_url: None,
41            user_agent: None,
42            timeout: None,
43        }
44    }
45
46    pub async fn login(
47        &self,
48        login_name: impl Into<String>,
49        password: impl Into<String>,
50    ) -> Result<AuthTokenResponse, Error> {
51        let response = self
52            .client
53            .post(self.url("/v1/auth/login")?)
54            .json(&AuthLoginRequest {
55                login_name: login_name.into(),
56                password: password.into(),
57            })
58            .send()
59            .await?;
60        let payload: AuthTokenResponse = self.read_json_response(response).await?;
61        self.set_access_token(payload.access_token.clone()).await;
62        Ok(payload)
63    }
64
65    pub async fn refresh(&self) -> Result<AuthTokenResponse, Error> {
66        let response = self
67            .client
68            .post(self.url("/v1/auth/refresh")?)
69            .send()
70            .await?;
71        if !response.status().is_success() {
72            let status = response.status();
73            let body = response.text().await.unwrap_or_default();
74            let message = parse_api_error_message(&body).unwrap_or_else(|| body.clone());
75            return Err(Error::RefreshFailed {
76                status: Some(status),
77                message,
78            });
79        }
80
81        let payload = response.json::<AuthTokenResponse>().await?;
82        self.set_access_token(payload.access_token.clone()).await;
83        Ok(payload)
84    }
85
86    pub async fn logout(&self) -> Result<(), Error> {
87        let response = self
88            .client
89            .post(self.url("/v1/auth/logout")?)
90            .send()
91            .await?;
92        self.read_empty_response(response).await?;
93        self.clear_access_token().await;
94        Ok(())
95    }
96
97    pub async fn me(&self) -> Result<AuthMeResponse, Error> {
98        self.send_json(Method::GET, "/v1/auth/me", None::<&()>).await
99    }
100
101    pub async fn healthz(&self) -> Result<HealthResponse, Error> {
102        let response = self.client.get(self.url("/healthz")?).send().await?;
103        self.read_json_response(response).await
104    }
105
106    pub async fn search(&self, request: SearchRequest) -> Result<SearchResponse, Error> {
107        self.send_json(Method::POST, "/v1/search", Some(&request)).await
108    }
109
110    pub async fn get_document(&self, document_id: i64) -> Result<DocumentResponse, Error> {
111        let path = format!("/v1/documents/{document_id}");
112        self.send_json(Method::GET, &path, None::<&()>).await
113    }
114
115    pub async fn list_groups(&self) -> Result<Vec<GroupResponse>, Error> {
116        self.send_json(Method::GET, "/v1/groups", None::<&()>).await
117    }
118
119    pub async fn get_group(&self, group_key: &str) -> Result<GroupResponse, Error> {
120        let path = format!("/v1/groups/{group_key}");
121        self.send_json(Method::GET, &path, None::<&()>).await
122    }
123
124    pub async fn list_projects(&self, group_key: &str) -> Result<Vec<ProjectResponse>, Error> {
125        let path = format!("/v1/groups/{group_key}/projects");
126        self.send_json(Method::GET, &path, None::<&()>).await
127    }
128
129    pub async fn get_project(
130        &self,
131        group_key: &str,
132        project_key: &str,
133    ) -> Result<ProjectResponse, Error> {
134        let path = format!("/v1/groups/{group_key}/projects/{project_key}");
135        self.send_json(Method::GET, &path, None::<&()>).await
136    }
137
138    pub async fn list_group_members(
139        &self,
140        group_key: &str,
141    ) -> Result<Vec<GroupMemberResponse>, Error> {
142        let path = format!("/v1/groups/{group_key}/members");
143        self.send_json(Method::GET, &path, None::<&()>).await
144    }
145
146    pub async fn list_project_members(
147        &self,
148        group_key: &str,
149        project_key: &str,
150    ) -> Result<Vec<ProjectMemberResponse>, Error> {
151        let path = format!("/v1/groups/{group_key}/projects/{project_key}/members");
152        self.send_json(Method::GET, &path, None::<&()>).await
153    }
154
155    pub async fn list_sources(&self) -> Result<Vec<context69_contracts::SourceStatus>, Error> {
156        let response: Vec<context69_contracts::SourceStatus> =
157            self.send_json(Method::GET, "/v1/sources", None::<&()>).await?;
158        Ok(response)
159    }
160
161    async fn send_json<TReq, TRes>(
162        &self,
163        method: Method,
164        path: &str,
165        body: Option<TReq>,
166    ) -> Result<TRes, Error>
167    where
168        TReq: serde::Serialize,
169        TRes: serde::de::DeserializeOwned,
170    {
171        self.ensure_authenticated().await?;
172        let request_body = body
173            .map(|value| serde_json::to_value(value))
174            .transpose()
175            .map_err(Error::from)?;
176
177        let response = self
178            .send_with_refresh(method.clone(), path, request_body.clone())
179            .await?;
180        self.read_json_response(response).await
181    }
182
183    async fn send_with_refresh(
184        &self,
185        method: Method,
186        path: &str,
187        body: Option<serde_json::Value>,
188    ) -> Result<Response, Error> {
189        let response = self
190            .send_request(method.clone(), path, body.clone(), true)
191            .await?;
192
193        if response.status() != StatusCode::UNAUTHORIZED {
194            return Ok(response);
195        }
196
197        match self.refresh().await {
198            Ok(_) => self.send_request(method, path, body, true).await,
199            Err(Error::RefreshFailed { status, message }) => {
200                Err(Error::RefreshFailed { status, message })
201            }
202            Err(other) => Err(Error::RefreshFailed {
203                status: None,
204                message: other.to_string(),
205            }),
206        }
207    }
208
209    async fn send_request(
210        &self,
211        method: Method,
212        path: &str,
213        body: Option<serde_json::Value>,
214        include_auth: bool,
215    ) -> Result<Response, Error> {
216        let url = self.url(path)?;
217        let mut request = self.client.request(method, url);
218        if include_auth {
219            request = self.authorized(request).await?;
220        }
221        if let Some(body) = body {
222            request = request.json(&body);
223        }
224        Ok(request.send().await?)
225    }
226
227    async fn authorized(&self, request: RequestBuilder) -> Result<RequestBuilder, Error> {
228        let token = self
229            .session
230            .read()
231            .await
232            .access_token
233            .clone()
234            .ok_or(Error::AuthenticationRequired)?;
235        Ok(request.header(AUTHORIZATION, format!("Bearer {token}")))
236    }
237
238    async fn ensure_authenticated(&self) -> Result<(), Error> {
239        if self.session.read().await.access_token.is_some() {
240            Ok(())
241        } else {
242            Err(Error::AuthenticationRequired)
243        }
244    }
245
246    fn url(&self, path: &str) -> Result<Url, Error> {
247        self.base_url
248            .join(path.trim_start_matches('/'))
249            .map_err(|source| Error::UrlJoin {
250                path: path.to_string(),
251                source,
252            })
253    }
254
255    async fn set_access_token(&self, token: String) {
256        self.session.write().await.access_token = Some(token);
257    }
258
259    async fn clear_access_token(&self) {
260        self.session.write().await.access_token = None;
261    }
262
263    async fn read_empty_response(&self, response: Response) -> Result<(), Error> {
264        let status = response.status();
265        if status.is_success() {
266            return Ok(());
267        }
268        Err(self.build_http_error(response).await)
269    }
270
271    async fn read_json_response<T: serde::de::DeserializeOwned>(
272        &self,
273        response: Response,
274    ) -> Result<T, Error> {
275        let status = response.status();
276        if !status.is_success() {
277            return Err(self.build_http_error(response).await);
278        }
279        Ok(response.json::<T>().await?)
280    }
281
282    async fn build_http_error(&self, response: Response) -> Error {
283        let status = response.status();
284        let body = response.text().await.unwrap_or_default();
285        Error::HttpStatus {
286            status,
287            api_error: parse_api_error_message(&body),
288            body,
289        }
290    }
291}
292
293impl Context69ClientBuilder {
294    pub fn base_url(mut self, base_url: &str) -> Result<Self, Error> {
295        let mut url =
296            Url::parse(base_url).map_err(|_| Error::InvalidBaseUrl(base_url.to_string()))?;
297        if !url.path().ends_with('/') {
298            let next_path = format!("{}/", url.path());
299            url.set_path(&next_path);
300        }
301        self.base_url = Some(url);
302        Ok(self)
303    }
304
305    pub fn user_agent(mut self, user_agent: impl Into<String>) -> Self {
306        self.user_agent = Some(user_agent.into());
307        self
308    }
309
310    pub fn timeout(mut self, timeout: Duration) -> Result<Self, Error> {
311        if timeout.is_zero() {
312            return Err(Error::InvalidTimeout(timeout));
313        }
314        self.timeout = Some(timeout);
315        Ok(self)
316    }
317
318    pub fn build(self) -> Result<Context69Client, Error> {
319        let base_url = self
320            .base_url
321            .ok_or_else(|| Error::InvalidBaseUrl("missing base_url".to_string()))?;
322        let mut builder = reqwest::Client::builder().cookie_store(true);
323        if let Some(user_agent) = self.user_agent {
324            builder = builder.default_headers({
325                let mut headers = reqwest::header::HeaderMap::new();
326                headers.insert(
327                    USER_AGENT,
328                    user_agent
329                        .parse()
330                        .map_err(|_| Error::InvalidHeader(user_agent.clone()))?,
331                );
332                headers
333            });
334        }
335        if let Some(timeout) = self.timeout {
336            builder = builder.timeout(timeout);
337        }
338        let client = builder.build()?;
339        Ok(Context69Client {
340            client,
341            base_url,
342            session: Arc::new(RwLock::new(SessionState::default())),
343        })
344    }
345}
346
347fn parse_api_error_message(body: &str) -> Option<String> {
348    serde_json::from_str::<ApiErrorResponse>(body)
349        .ok()
350        .map(|value| value.error)
351}
352
#[cfg(test)]
mod tests {
    use std::sync::{
        Arc,
        atomic::{AtomicUsize, Ordering},
    };

    use super::*;
    use axum::{
        Json, Router,
        extract::State,
        http::{HeaderMap, StatusCode},
        response::IntoResponse,
        routing::{get, post},
    };
    use context69_contracts::{
        AuthUserResponse, GroupKind, HealthStatus, MembershipRole, SearchHit, Visibility,
    };
    use serde_json::json;
    use tokio::net::TcpListener;

    /// Call counters shared between the test server's handlers and the test bodies.
    #[derive(Clone, Default)]
    struct TestState {
        search_calls: Arc<AtomicUsize>,
        refresh_calls: Arc<AtomicUsize>,
    }

    /// Starts an in-process axum server on an ephemeral localhost port.
    ///
    /// Returns its base URL together with the shared counters.
    async fn spawn_test_server() -> (String, TestState) {
        let state = TestState::default();
        let app = Router::new()
            .route("/healthz", get(health_handler))
            .route("/v1/auth/login", post(login_handler))
            .route("/v1/auth/refresh", post(refresh_handler))
            .route("/v1/search", post(search_handler))
            .route("/v1/groups", get(groups_handler))
            .with_state(state.clone());

        let listener = TcpListener::bind("127.0.0.1:0").await.expect("bind listener");
        let addr = listener.local_addr().expect("local addr");
        tokio::spawn(async move {
            axum::serve(listener, app).await.expect("serve app");
        });
        (format!("http://{addr}"), state)
    }

    /// Builds a client for `base_url` with default options.
    fn test_client(base_url: &str) -> Context69Client {
        Context69Client::builder()
            .base_url(base_url)
            .expect("base url")
            .build()
            .expect("client")
    }

    /// Search request with `query` and fixed values for every other field.
    fn search_request(query: &str) -> SearchRequest {
        SearchRequest {
            query: query.to_string(),
            limit: 8,
            source_key: None,
            group_key: None,
            project_key: None,
            published_after: None,
            published_before: None,
        }
    }

    async fn health_handler() -> Json<HealthResponse> {
        Json(HealthResponse {
            status: HealthStatus::Ok,
            indexed_chunks: Some(7),
            db_ok: None,
            qdrant_ok: None,
        })
    }

    /// Issues `token-initial` and sets the refresh cookie that `/v1/auth/refresh` relies on.
    async fn login_handler() -> impl IntoResponse {
        (
            [(
                "set-cookie",
                "context69_refresh=refresh-ok; HttpOnly; Path=/",
            )],
            Json(token_response("token-initial")),
        )
    }

    /// Always succeeds with `token-refreshed` and counts its invocations.
    async fn refresh_handler(State(state): State<TestState>) -> impl IntoResponse {
        state.refresh_calls.fetch_add(1, Ordering::SeqCst);
        (
            [(
                "set-cookie",
                "context69_refresh=refresh-ok; HttpOnly; Path=/",
            )],
            Json(token_response("token-refreshed")),
        )
    }

    /// Search endpoint with scripted failures.
    ///
    /// - The query "bad request" gets a 400 regardless of auth.
    /// - The very first call made with `token-initial` gets a 401, simulating an expired token.
    /// - Any request not carrying `token-refreshed` gets a 401.
    async fn search_handler(
        State(state): State<TestState>,
        headers: HeaderMap,
        Json(request): Json<SearchRequest>,
    ) -> impl IntoResponse {
        let call = state.search_calls.fetch_add(1, Ordering::SeqCst);
        let bearer = headers
            .get(AUTHORIZATION)
            .and_then(|value| value.to_str().ok())
            .unwrap_or_default();

        if request.query == "bad request" {
            return (
                StatusCode::BAD_REQUEST,
                Json(ApiErrorResponse {
                    error: "invalid query".to_string(),
                }),
            )
                .into_response();
        }

        if call == 0 && bearer == "Bearer token-initial" {
            return (
                StatusCode::UNAUTHORIZED,
                Json(ApiErrorResponse {
                    error: "expired".to_string(),
                }),
            )
                .into_response();
        }

        if bearer != "Bearer token-refreshed" {
            return (
                StatusCode::UNAUTHORIZED,
                Json(ApiErrorResponse {
                    error: "missing bearer token".to_string(),
                }),
            )
                .into_response();
        }

        Json(SearchResponse {
            query: request.query,
            hits: vec![SearchHit {
                chunk_id: uuid::Uuid::nil(),
                document_id: 42,
                group_key: "team".to_string(),
                project_key: "docs".to_string(),
                visibility: Visibility::Private,
                source_key: "source".to_string(),
                external_id: "ext-1".to_string(),
                title: "Document".to_string(),
                summary: Some("Summary".to_string()),
                source_uri: "https://example.test/doc".to_string(),
                published_at: None,
                chunk_index: 0,
                chunk_text: "hello".to_string(),
                score: 0.9,
                vector_score: Some(0.9),
                keyword_score: None,
                rerank_score: None,
                match_reason: None,
                metadata_json: json!({}),
                library_file_id: None,
                library_section_label: None,
                library_path: None,
                is_library_file: false,
            }],
        })
        .into_response()
    }

    /// Returns one group, but only for requests carrying `token-refreshed`.
    async fn groups_handler(headers: HeaderMap) -> impl IntoResponse {
        let bearer = headers
            .get(AUTHORIZATION)
            .and_then(|value| value.to_str().ok())
            .unwrap_or_default();
        if bearer != "Bearer token-refreshed" {
            return (
                StatusCode::UNAUTHORIZED,
                Json(ApiErrorResponse {
                    error: "missing bearer token".to_string(),
                }),
            )
                .into_response();
        }

        Json(vec![GroupResponse {
            group_id: 1,
            group_key: "team".to_string(),
            parent_group_key: None,
            name: "Team".to_string(),
            visibility: Visibility::Private,
            kind: GroupKind::Shared,
            current_role: Some(MembershipRole::Owner),
            created_at: chrono::Utc::now(),
            updated_at: chrono::Utc::now(),
        }])
        .into_response()
    }

    fn token_response(access_token: &str) -> AuthTokenResponse {
        AuthTokenResponse {
            access_token: access_token.to_string(),
            token_type: "Bearer".to_string(),
            expires_in_secs: 3600,
            user: AuthUserResponse {
                user_id: 1,
                login_name: "admin".to_string(),
                display_name: "Administrator".to_string(),
                is_admin: true,
                disabled_at: None,
                personal_group_key: "admin".to_string(),
                personal_group_role: Some(MembershipRole::Owner),
            },
        }
    }

    #[test]
    fn builder_normalizes_base_url_with_trailing_slash() {
        let client = test_client("http://localhost:8096");

        assert_eq!(client.url("/healthz").expect("url").as_str(), "http://localhost:8096/healthz");
    }

    // A root URL already parses with path "/", so only a base with a path prefix exercises the
    // builder's "append a trailing slash" branch. Without it, `Url::join` would replace "api".
    #[test]
    fn builder_keeps_base_path_prefix() {
        let client = test_client("http://localhost:8096/api");

        assert_eq!(
            client.url("/healthz").expect("url").as_str(),
            "http://localhost:8096/api/healthz"
        );
    }

    #[test]
    fn parse_api_error_body() {
        let body = r#"{"error":"missing bearer token"}"#;
        assert_eq!(
            parse_api_error_message(body),
            Some("missing bearer token".to_string())
        );
    }

    #[tokio::test]
    async fn healthz_does_not_require_login() {
        let (base_url, _) = spawn_test_server().await;
        let client = test_client(&base_url);

        let health = client.healthz().await.expect("health response");
        assert!(matches!(health.status, HealthStatus::Ok));
        assert_eq!(health.indexed_chunks, Some(7));
    }

    #[tokio::test]
    async fn protected_api_requires_login() {
        let (base_url, _) = spawn_test_server().await;
        let client = test_client(&base_url);

        let error = client
            .list_groups()
            .await
            .expect_err("should require authentication");
        assert!(matches!(error, Error::AuthenticationRequired));
    }

    #[tokio::test]
    async fn search_refreshes_once_and_retries() {
        let (base_url, state) = spawn_test_server().await;
        let client = test_client(&base_url);

        client.login("admin", "secret").await.expect("login");
        let response = client
            .search(search_request("policy"))
            .await
            .expect("search response");

        assert_eq!(response.hits.len(), 1);
        assert_eq!(state.refresh_calls.load(Ordering::SeqCst), 1);
        assert_eq!(state.search_calls.load(Ordering::SeqCst), 2);
    }

    #[tokio::test]
    async fn surfaces_api_error_message() {
        let (base_url, _) = spawn_test_server().await;
        let client = test_client(&base_url);

        client.login("admin", "secret").await.expect("login");
        let error = client
            .search(search_request("bad request"))
            .await
            .expect_err("should fail");

        match error {
            Error::HttpStatus {
                status,
                api_error,
                ..
            } => {
                assert_eq!(status, StatusCode::BAD_REQUEST);
                assert_eq!(api_error.as_deref(), Some("invalid query"));
            }
            other => panic!("unexpected error: {other}"),
        }
    }
}