Skip to main content

nyquest_backend_winhttp/async/client.rs

1//! Async WinHTTP client implementation.
2
3use std::future::Future;
4#[cfg(feature = "async-stream")]
5use std::ops::Range;
6use std::sync::Arc;
7
8use futures_channel::oneshot;
9use nyquest_interface::client::ClientOptions;
10use nyquest_interface::r#async::{AsyncBackend, AsyncClient, Request};
11use nyquest_interface::{Error as NyquestError, Result as NyquestResult};
12
13use super::callback::setup_session_callback;
14use super::context::{RequestContext, RequestState};
15use super::response::WinHttpAsyncResponse;
16use crate::error::{WinHttpError, WinHttpResultExt};
17use crate::handle::RequestHandle;
18use crate::r#async::state_fut::wait_for_state;
19use crate::r#async::threadpool::submit_callback;
20use crate::request::{
21    create_request, method_to_cwstr, prepare_additional_headers, prepare_body, PreparedBody,
22};
23use crate::session::WinHttpSession;
24use crate::stream::{DataOrStream, StreamWriter};
25use crate::url::{concat_url, ParsedUrl};
26use crate::WinHttpBackend;
27
28#[cfg(feature = "async-stream")]
29use nyquest_interface::r#async::BoxedStream;
30
31/// Async WinHTTP client.
32#[derive(Clone)]
33pub struct WinHttpAsyncClient {
34    session: Arc<WinHttpSession>,
35}
36
37impl WinHttpAsyncClient {
38    pub(crate) async fn new(options: ClientOptions) -> NyquestResult<Self> {
39        // Create async session
40        let session = WinHttpSession::new(options, true).into_nyquest()?;
41
42        // Set up the callback on the session
43        setup_session_callback(&session.session).into_nyquest()?;
44
45        Ok(Self { session })
46    }
47}
48
49impl AsyncClient for WinHttpAsyncClient {
50    type Response = WinHttpAsyncResponse;
51
52    fn request(&self, req: Request) -> impl Future<Output = NyquestResult<Self::Response>> + Send {
53        let session = self.session.clone();
54        async move {
55            // Prepare headers and body before spawning to threadpool
56            let mut prepared_body;
57
58            // Create the request context
59            let ctx = RequestContext::new();
60
61            let body_len;
62            let (setup_tx, setup_rx) = oneshot::channel();
63            submit_callback({
64                let url = concat_url(session.base_cwurl.as_deref(), &req.relative_uri)?;
65                let method = method_to_cwstr(&req.method);
66                prepared_body = prepare_body(req.body, get_stream_content_length);
67                let headers_str = prepare_additional_headers(
68                    &req.additional_headers,
69                    &session.options,
70                    &prepared_body,
71                );
72
73                body_len = prepared_body.body_len();
74                let is_stream = matches!(prepared_body, PreparedBody::Stream { .. });
75                // Store body data in context - it must remain valid until SENDREQUEST_COMPLETE
76                ctx.set_body(prepared_body.take_body().unwrap_or_default());
77
78                let ctx = Arc::downgrade(&ctx);
79                let session = session.clone();
80                move || {
81                    let parsed_url = match ParsedUrl::parse(&url) {
82                        Some(p) => p,
83                        None => {
84                            let _ = setup_tx.send(Err(NyquestError::InvalidUrl));
85                            return;
86                        }
87                    };
88
89                    let (connection, request) = match create_request(&session, &parsed_url, &method)
90                    {
91                        Ok(handles) => handles,
92                        Err(e) => {
93                            let _ = setup_tx.send(Err(e.into()));
94                            return;
95                        }
96                    };
97                    drop(session);
98                    let Some(ctx) = ctx.upgrade() else {
99                        return;
100                    };
101                    let result = if headers_str.is_empty() {
102                        Ok(())
103                    } else {
104                        request.add_headers(&headers_str)
105                    };
106                    let result = result.and_then(|()| {
107                        let context = Arc::into_raw(ctx.clone()) as usize;
108                        let res = match (is_stream, body_len) {
109                            (true, Some(len)) => request.send_with_total_length(len, context),
110                            (true, None) => request.send_chunked(context),
111                            (false, _) => {
112                                let (body_ptr, body_len) = ctx.get_body_ptr();
113                                unsafe { request.send(body_ptr, body_len, context) }
114                            }
115                        };
116                        if res.is_err() {
117                            let _ = unsafe { Arc::from_raw(context as *const RequestContext) };
118                        }
119                        res
120                    });
121
122                    let _ = setup_tx.send(result.map(|()| (connection, request)).into_nyquest());
123                }
124            })?;
125
126            // Wait for the setup to complete
127            let (connection, request) = setup_rx.await.map_err(|_| {
128                nyquest_interface::Error::Io(std::io::Error::other("setup channel closed"))
129            })??;
130
131            wait_for_state(&*ctx, RequestState::HeadersSent).await?;
132
133            // If streaming, poll the stream writer to send data
134            #[cfg(feature = "async-stream")]
135            if let PreparedBody::Stream { stream_parts, .. } = prepared_body {
136                poll_stream_upload(&ctx, &request, stream_parts, body_len.is_none()).await?;
137            }
138
139            request.receive_response().into_nyquest()?;
140
141            // Now wait for headers to be available
142            wait_for_state(&*ctx, RequestState::HeadersReceived).await?;
143
144            // Build the response
145            let status = request.query_status_code()?;
146            let content_length = request.query_content_length();
147
148            Ok(WinHttpAsyncResponse::new(
149                ctx,
150                status,
151                content_length,
152                session.options.max_response_buffer_size,
153                session.clone(),
154                connection,
155                request,
156            ))
157        }
158    }
159}
160
/// Returns the declared content length of a body stream.
///
/// Returns `Some(content_length)` for a `BoxedStream::Sized` stream and `None` for an unsized
/// one.
#[cfg(feature = "async-stream")]
fn get_stream_content_length(stream: &BoxedStream) -> Option<u64> {
    let declared = match stream {
        BoxedStream::Unsized { .. } => None,
        BoxedStream::Sized { content_length, .. } => Some(content_length),
    };
    declared.copied()
}

/// Fallback used when the `async-stream` feature is disabled: never reports a known length.
#[cfg(not(feature = "async-stream"))]
fn get_stream_content_length(_stream: &impl Sized) -> Option<u64> {
    Option::None
}
174
/// Drives a streaming request-body upload through `WinHttpWriteData`.
///
/// Repeatedly asks the [`StreamWriter`] to fill a buffer from `stream_parts`, writes it with
/// `write_all_data_async`, and returns the buffer to the writer, until the writer reports
/// that it is finished. `is_chunked` is forwarded to the writer unchanged.
///
/// # Errors
///
/// Propagates the first error from reading the body stream or writing to the request.
#[cfg(feature = "async-stream")]
async fn poll_stream_upload(
    ctx: &RequestContext,
    request: &RequestHandle,
    stream_parts: Vec<DataOrStream<BoxedStream>>,
    is_chunked: bool,
) -> NyquestResult<()> {
    use nyquest_interface::r#async::futures_io::AsyncRead as _;
    use std::pin::Pin;

    let mut writer = StreamWriter::new(stream_parts, is_chunked);
    loop {
        if writer.is_finished() {
            break Ok(());
        }
        let (chunk, pending) = writer
            .take_buffer(|source, dest, cx| Pin::new(source).poll_read(cx, dest))
            .await?;
        let recycled = write_all_data_async(ctx, request, chunk, pending).await?;
        writer.advance(recycled);
    }
}
200
/// Writes `data[range]` to `request` via `WinHttpWriteData`, waiting for a `WriteComplete`
/// notification after each call, and returns the buffer so the caller can reuse it.
///
/// While writing, the buffer is stored in `ctx` rather than owned by this future, so the
/// pointer handed to WinHTTP refers to memory held by the shared request context. Partial
/// completions resume from the reported byte count. Ranges longer than `u32::MAX` are sent
/// in several calls, because `WinHttpWriteData` takes a 32-bit length.
///
/// # Errors
///
/// Returns an error in any of these cases:
/// - `WinHttpWriteData` fails;
/// - waiting for `WriteComplete` fails;
/// - a write completes having transferred zero bytes. Re-issuing the same write would
///   otherwise loop forever.
#[cfg(feature = "async-stream")]
async fn write_all_data_async(
    ctx: &RequestContext,
    request: &RequestHandle,
    data: Vec<u8>,
    mut range: Range<usize>,
) -> NyquestResult<Vec<u8>> {
    if data.is_empty() {
        // Nothing to send; hand the caller's buffer (and its capacity) straight back.
        return Ok(data);
    }

    // Store the data in the context so it remains valid during the async operation
    ctx.set_write_buffer(data.into());
    while !range.is_empty() {
        let ptr = ctx.prepare_for_writing();
        // Clamp instead of truncating with `as`: any remainder is sent on the next iteration.
        let chunk_len = u32::try_from(range.len()).unwrap_or(u32::MAX);

        // SAFETY: `ptr` points at the write buffer stored in `ctx` above, which stays there
        // until `take_write_buffer` below, so it outlives the asynchronous write. Assumes the
        // caller's `range` lies within that buffer (TODO confirm), so `ptr.add(range.start)`
        // plus `chunk_len <= range.len()` bytes stays in bounds.
        let result = unsafe {
            windows_sys::Win32::Networking::WinHttp::WinHttpWriteData(
                request.as_raw(),
                ptr.add(range.start) as *const std::ffi::c_void,
                chunk_len,
                std::ptr::null_mut(),
            )
        };

        if result == 0 {
            return Err(WinHttpError::from_last_error("WinHttpWriteData").into());
        }

        // Wait for WRITE_COMPLETE callback
        let res = wait_for_state(ctx, RequestState::WriteComplete).await?;
        if res.bytes_transferred == 0 {
            // No progress: retrying the identical write would spin forever.
            return Err(NyquestError::Io(std::io::Error::new(
                std::io::ErrorKind::WriteZero,
                "WinHttpWriteData completed without writing any bytes",
            )));
        }
        range.start += res.bytes_transferred;
    }
    Ok(ctx.take_write_buffer().into_owned())
}
237
238impl AsyncBackend for WinHttpBackend {
239    type AsyncClient = WinHttpAsyncClient;
240
241    async fn create_async_client(
242        &self,
243        options: ClientOptions,
244    ) -> NyquestResult<Self::AsyncClient> {
245        WinHttpAsyncClient::new(options).await
246    }
247}