Skip to main content

openai_oxide/resources/beta/
threads.rs

1// Threads resource — client.beta().threads()
2
3use super::BETA_HEADER;
4use crate::client::OpenAI;
5use crate::error::OpenAIError;
6use crate::pagination::{Page, Paginator};
7use crate::types::beta::{
8    Message, MessageCreateRequest, MessageList, MessageListParams, Thread, ThreadCreateRequest,
9    ThreadDeleted,
10};
11
12/// Access thread endpoints (beta).
13///
14/// API reference: <https://platform.openai.com/docs/api-reference/threads>
15pub struct Threads<'a> {
16    client: &'a OpenAI,
17}
18
19impl<'a> Threads<'a> {
20    pub(crate) fn new(client: &'a OpenAI) -> Self {
21        Self { client }
22    }
23
24    /// Create a thread.
25    ///
26    /// `POST /threads`
27    pub async fn create(&self, request: ThreadCreateRequest) -> Result<Thread, OpenAIError> {
28        let response = self
29            .client
30            .request(reqwest::Method::POST, "/threads")
31            .header(BETA_HEADER.0, BETA_HEADER.1)
32            .json(&request)
33            .send()
34            .await?;
35        OpenAI::handle_response(response).await
36    }
37
38    /// Retrieve a thread.
39    ///
40    /// `GET /threads/{thread_id}`
41    pub async fn retrieve(&self, thread_id: &str) -> Result<Thread, OpenAIError> {
42        let response = self
43            .client
44            .request(reqwest::Method::GET, &format!("/threads/{thread_id}"))
45            .header(BETA_HEADER.0, BETA_HEADER.1)
46            .send()
47            .await?;
48        OpenAI::handle_response(response).await
49    }
50
51    /// Delete a thread.
52    ///
53    /// `DELETE /threads/{thread_id}`
54    pub async fn delete(&self, thread_id: &str) -> Result<ThreadDeleted, OpenAIError> {
55        let response = self
56            .client
57            .request(reqwest::Method::DELETE, &format!("/threads/{thread_id}"))
58            .header(BETA_HEADER.0, BETA_HEADER.1)
59            .send()
60            .await?;
61        OpenAI::handle_response(response).await
62    }
63
64    /// Access messages sub-resource.
65    pub fn messages(&self, thread_id: &str) -> Messages<'_> {
66        Messages {
67            client: self.client,
68            thread_id: thread_id.to_string(),
69        }
70    }
71}
72
73/// Thread messages sub-resource.
74pub struct Messages<'a> {
75    client: &'a OpenAI,
76    thread_id: String,
77}
78
79impl<'a> Messages<'a> {
80    /// Create a message in a thread.
81    ///
82    /// `POST /threads/{thread_id}/messages`
83    pub async fn create(&self, request: MessageCreateRequest) -> Result<Message, OpenAIError> {
84        let response = self
85            .client
86            .request(
87                reqwest::Method::POST,
88                &format!("/threads/{}/messages", self.thread_id),
89            )
90            .header(BETA_HEADER.0, BETA_HEADER.1)
91            .json(&request)
92            .send()
93            .await?;
94        OpenAI::handle_response(response).await
95    }
96
97    /// List messages in a thread.
98    ///
99    /// `GET /threads/{thread_id}/messages`
100    pub async fn list(&self) -> Result<MessageList, OpenAIError> {
101        let response = self
102            .client
103            .request(
104                reqwest::Method::GET,
105                &format!("/threads/{}/messages", self.thread_id),
106            )
107            .header(BETA_HEADER.0, BETA_HEADER.1)
108            .send()
109            .await?;
110        OpenAI::handle_response(response).await
111    }
112
113    /// List messages with pagination parameters.
114    ///
115    /// `GET /threads/{thread_id}/messages`
116    pub async fn list_page(&self, params: MessageListParams) -> Result<MessageList, OpenAIError> {
117        let response = self
118            .client
119            .request(
120                reqwest::Method::GET,
121                &format!("/threads/{}/messages", self.thread_id),
122            )
123            .header(BETA_HEADER.0, BETA_HEADER.1)
124            .query(&params.to_query())
125            .send()
126            .await?;
127        OpenAI::handle_response(response).await
128    }
129
130    /// Auto-paginate through all messages in a thread.
131    pub fn list_auto(&self, params: MessageListParams) -> Paginator<Message> {
132        let client = self.client.clone();
133        let thread_id = self.thread_id.clone();
134        let base_params = params;
135        Paginator::new(move |cursor| {
136            let client = client.clone();
137            let thread_id = thread_id.clone();
138            let mut params = base_params.clone();
139            if cursor.is_some() {
140                params.after = cursor;
141            }
142            async move {
143                let response = client
144                    .request(
145                        reqwest::Method::GET,
146                        &format!("/threads/{thread_id}/messages"),
147                    )
148                    .header(BETA_HEADER.0, BETA_HEADER.1)
149                    .query(&params.to_query())
150                    .send()
151                    .await?;
152                let list: MessageList = OpenAI::handle_response(response).await?;
153                let after_cursor = list
154                    .last_id
155                    .clone()
156                    .or_else(|| list.data.last().map(|m| m.id.clone()));
157                Ok(Page {
158                    has_more: list.has_more.unwrap_or(false),
159                    after_cursor,
160                    data: list.data,
161                })
162            }
163        })
164    }
165}
166
#[cfg(test)]
mod tests {
    use crate::OpenAI;
    use crate::config::ClientConfig;
    use crate::types::beta::ThreadCreateRequest;

    /// Minimal thread object returned by the mocked create/retrieve endpoints.
    const THREAD_JSON: &str = r#"{
        "id": "thread_abc123",
        "object": "thread",
        "created_at": 1699012949
    }"#;

    /// Build a client whose base URL points at the mock server.
    fn test_client(server: &mockito::Server) -> OpenAI {
        OpenAI::with_config(ClientConfig::new("sk-test").base_url(server.url()))
    }

    #[tokio::test]
    async fn test_threads_create() {
        let mut server = mockito::Server::new_async().await;
        let mock = server
            .mock("POST", "/threads")
            .match_header("OpenAI-Beta", "assistants=v2")
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(THREAD_JSON)
            .create_async()
            .await;

        let client = test_client(&server);
        let thread = client
            .beta()
            .threads()
            .create(ThreadCreateRequest::default())
            .await
            .unwrap();
        assert_eq!(thread.id, "thread_abc123");
        mock.assert_async().await;
    }

    #[tokio::test]
    async fn test_threads_retrieve() {
        let mut server = mockito::Server::new_async().await;
        let mock = server
            .mock("GET", "/threads/thread_abc123")
            .match_header("OpenAI-Beta", "assistants=v2")
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(THREAD_JSON)
            .create_async()
            .await;

        let client = test_client(&server);
        let thread = client
            .beta()
            .threads()
            .retrieve("thread_abc123")
            .await
            .unwrap();
        assert_eq!(thread.id, "thread_abc123");
        mock.assert_async().await;
    }

    #[tokio::test]
    async fn test_threads_delete() {
        let mut server = mockito::Server::new_async().await;
        let mock = server
            .mock("DELETE", "/threads/thread_abc123")
            .match_header("OpenAI-Beta", "assistants=v2")
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(r#"{"id": "thread_abc123", "object": "thread.deleted", "deleted": true}"#)
            .create_async()
            .await;

        let client = test_client(&server);
        let resp = client
            .beta()
            .threads()
            .delete("thread_abc123")
            .await
            .unwrap();
        assert!(resp.deleted);
        mock.assert_async().await;
    }

    #[tokio::test]
    async fn test_thread_messages_list_empty() {
        let mut server = mockito::Server::new_async().await;
        let mock = server
            .mock("GET", "/threads/thread_abc123/messages")
            .match_header("OpenAI-Beta", "assistants=v2")
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(
                r#"{
                    "object": "list",
                    "data": [],
                    "first_id": null,
                    "last_id": null,
                    "has_more": false
                }"#,
            )
            .create_async()
            .await;

        let client = test_client(&server);
        let list = client
            .beta()
            .threads()
            .messages("thread_abc123")
            .list()
            .await
            .unwrap();
        assert!(list.data.is_empty());
        mock.assert_async().await;
    }
}