Skip to main content

kagi_sdk/session_web/
mod.rs

1pub mod models;
2
3use crate::{
4    client::KagiClient,
5    error::KagiError,
6    parsing::{
7        parse_html_search_response, parse_kagi_failure_payload, parse_summarize_response,
8        parse_summary_stream_response,
9    },
10    routing::EndpointId,
11    transport::{RequestBody, TransportResponse},
12};
13use scraper::{Html, Selector};
14
15#[allow(deprecated)]
16use self::models::{
17    HtmlSearchRequest, HtmlSearchResponse, SummaryLabsTextRequest, SummaryLabsUrlRequest,
18};
19use self::models::{
20    SearchRequest, SearchResponse, SummarizeRequest, SummarizeResponse, SummaryStreamResponse,
21};
22
23#[derive(Debug)]
24pub struct SessionWeb<'a> {
25    client: &'a KagiClient,
26}
27
28impl<'a> SessionWeb<'a> {
29    pub(crate) fn new(client: &'a KagiClient) -> Self {
30        Self { client }
31    }
32
33    pub async fn search(&self, request: SearchRequest) -> Result<SearchResponse, KagiError> {
34        let response = self
35            .request(
36                EndpointId::SessionHtmlSearch,
37                request.into_query(),
38                RequestBody::Empty,
39            )
40            .await?;
41
42        if (300..400).contains(&response.status) {
43            return Err(KagiError::InvalidSession {
44                endpoint: response.endpoint,
45                status: response.status,
46                message: format!(
47                    "session-web request redirected to {:?}",
48                    response.redirect_location
49                ),
50            });
51        }
52
53        if response.status >= 400 {
54            return Err(build_http_api_failure(&response));
55        }
56
57        if let Some((code, message)) = parse_kagi_failure_payload(&response.body) {
58            return Err(KagiError::ApiFailure {
59                endpoint: response.endpoint,
60                status: response.status,
61                code,
62                message,
63            });
64        }
65
66        let parsed_search = parse_html_search_response(response.endpoint, &response.body);
67        if parsed_search.is_ok() {
68            return parsed_search;
69        }
70
71        if response_looks_like_kagi_auth_interstitial(&response) {
72            return Err(KagiError::InvalidSession {
73                endpoint: response.endpoint,
74                status: response.status,
75                message: "response matched Kagi auth interstitial structure".to_string(),
76            });
77        }
78
79        parsed_search
80    }
81
82    pub async fn summarize(
83        &self,
84        request: SummarizeRequest,
85    ) -> Result<SummarizeResponse, KagiError> {
86        let response = self.request_summarize(request, false).await?;
87
88        validate_non_search_session_response(&response)?;
89        parse_summarize_response(response.endpoint, &response.body)
90    }
91
92    pub async fn summarize_stream(
93        &self,
94        request: SummarizeRequest,
95    ) -> Result<SummaryStreamResponse, KagiError> {
96        let response = self.request_summarize(request, true).await?;
97
98        validate_non_search_session_response(&response)?;
99        parse_summary_stream_response(response.endpoint, &response.body)
100    }
101
102    #[deprecated(note = "use search(SearchRequest) instead")]
103    #[doc(hidden)]
104    #[allow(deprecated)]
105    pub async fn html_search(
106        &self,
107        request: HtmlSearchRequest,
108    ) -> Result<HtmlSearchResponse, KagiError> {
109        self.search(request).await
110    }
111
112    #[deprecated(
113        note = "use summarize(...) or summarize_stream(...) with SummarizeRequest instead"
114    )]
115    #[doc(hidden)]
116    #[allow(deprecated)]
117    pub async fn summary_labs_url(
118        &self,
119        request: SummaryLabsUrlRequest,
120    ) -> Result<SummaryStreamResponse, KagiError> {
121        self.summarize_stream(request.into_summarize_request())
122            .await
123    }
124
125    #[deprecated(
126        note = "use summarize(...) or summarize_stream(...) with SummarizeRequest instead"
127    )]
128    #[doc(hidden)]
129    #[allow(deprecated)]
130    pub async fn summary_labs_text(
131        &self,
132        request: SummaryLabsTextRequest,
133    ) -> Result<SummaryStreamResponse, KagiError> {
134        self.summarize_stream(request.into_summarize_request())
135            .await
136    }
137
138    async fn request(
139        &self,
140        endpoint: EndpointId,
141        query: Vec<(String, String)>,
142        body: RequestBody,
143    ) -> Result<TransportResponse, KagiError> {
144        self.client
145            .transport()
146            .execute(self.client.credentials(), endpoint, &query, body)
147            .await
148    }
149
150    async fn request_summarize(
151        &self,
152        request: SummarizeRequest,
153        stream: bool,
154    ) -> Result<TransportResponse, KagiError> {
155        if let Some(query) = request.clone().into_query(stream) {
156            return self
157                .request(EndpointId::SessionSummaryLabsGet, query, RequestBody::Empty)
158                .await;
159        }
160
161        let form = request
162            .into_form(stream)
163            .expect("text summarize request must produce form fields");
164
165        self.request(
166            EndpointId::SessionSummaryLabsPost,
167            Vec::new(),
168            RequestBody::Form(form),
169        )
170        .await
171    }
172}
173
174fn validate_non_search_session_response(response: &TransportResponse) -> Result<(), KagiError> {
175    if (300..400).contains(&response.status) {
176        return Err(KagiError::InvalidSession {
177            endpoint: response.endpoint,
178            status: response.status,
179            message: format!(
180                "session-web request redirected to {:?}",
181                response.redirect_location
182            ),
183        });
184    }
185
186    if response_looks_like_kagi_auth_interstitial(response) {
187        return Err(KagiError::InvalidSession {
188            endpoint: response.endpoint,
189            status: response.status,
190            message: "response matched Kagi auth interstitial structure".to_string(),
191        });
192    }
193
194    if let Some((code, message)) = parse_kagi_failure_payload(&response.body) {
195        return Err(KagiError::ApiFailure {
196            endpoint: response.endpoint,
197            status: response.status,
198            code,
199            message,
200        });
201    }
202
203    if response.status >= 400 {
204        return Err(build_http_api_failure(response));
205    }
206
207    Ok(())
208}
209
210fn build_http_api_failure(response: &TransportResponse) -> KagiError {
211    if let Some((code, message)) = parse_kagi_failure_payload(&response.body) {
212        return KagiError::ApiFailure {
213            endpoint: response.endpoint,
214            status: response.status,
215            code,
216            message,
217        };
218    }
219
220    let fallback_message = if response.body.trim().is_empty() {
221        format!(
222            "HTTP {} returned without parseable Kagi failure payload",
223            response.status
224        )
225    } else {
226        response.body.clone()
227    };
228
229    KagiError::ApiFailure {
230        endpoint: response.endpoint,
231        status: response.status,
232        code: None,
233        message: fallback_message,
234    }
235}
236
237fn response_looks_like_kagi_auth_interstitial(response: &TransportResponse) -> bool {
238    if !response_looks_like_html(response) {
239        return false;
240    }
241
242    let document = Html::parse_document(&response.body);
243    let login_form_selector =
244        Selector::parse("form[action='/auth/login'], form[action*='/auth/login']")
245            .expect("static selector must compile");
246    let password_input_selector = Selector::parse("input[type='password'], input[name='password']")
247        .expect("static selector must compile");
248    let auth_link_selector = Selector::parse("a[href='/auth/login'], a[href*='/auth/login']")
249        .expect("static selector must compile");
250    let auth_heading_selector =
251        Selector::parse("title, h1, h2").expect("static selector must compile");
252
253    let has_auth_form_with_password = document
254        .select(&login_form_selector)
255        .any(|form| form.select(&password_input_selector).next().is_some());
256    if has_auth_form_with_password {
257        return true;
258    }
259
260    let has_auth_link = document.select(&auth_link_selector).next().is_some();
261    if !has_auth_link {
262        return false;
263    }
264
265    document.select(&auth_heading_selector).any(|node| {
266        let heading_text = node
267            .text()
268            .collect::<Vec<_>>()
269            .join(" ")
270            .to_ascii_lowercase();
271
272        heading_text.contains("sign in")
273            || heading_text.contains("log in")
274            || heading_text.contains("session expired")
275            || heading_text.contains("invalid session")
276    })
277}
278
279fn response_looks_like_html(response: &TransportResponse) -> bool {
280    if response
281        .content_type
282        .as_deref()
283        .is_some_and(|value| value.to_ascii_lowercase().contains("text/html"))
284    {
285        return true;
286    }
287
288    let trimmed_body = response.body.trim_start().to_ascii_lowercase();
289    trimmed_body.starts_with("<!doctype html") || trimmed_body.starts_with("<html")
290}