dynamo_async_openai/
threads.rs

1// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3//
4// Based on https://github.com/64bit/async-openai/ by Himanshu Neema
5// Original Copyright (c) 2022 Himanshu Neema
6// Licensed under MIT License (see ATTRIBUTIONS-Rust.md)
7//
8// Modifications Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES.
9// Licensed under Apache 2.0
10
11use crate::{
12    Client, Messages, Runs,
13    config::Config,
14    error::OpenAIError,
15    types::{
16        AssistantEventStream, CreateThreadAndRunRequest, CreateThreadRequest, DeleteThreadResponse,
17        ModifyThreadRequest, RunObject, ThreadObject,
18    },
19};
20
21/// Create threads that assistants can interact with.
22///
23/// Related guide: [Assistants](https://platform.openai.com/docs/assistants/overview)
24pub struct Threads<'c, C: Config> {
25    client: &'c Client<C>,
26}
27
28impl<'c, C: Config> Threads<'c, C> {
29    pub fn new(client: &'c Client<C>) -> Self {
30        Self { client }
31    }
32
33    /// Call [Messages] group API to manage message in [thread_id] thread.
34    pub fn messages(&self, thread_id: &str) -> Messages<C> {
35        Messages::new(self.client, thread_id)
36    }
37
38    /// Call [Runs] group API to manage runs in [thread_id] thread.
39    pub fn runs(&self, thread_id: &str) -> Runs<C> {
40        Runs::new(self.client, thread_id)
41    }
42
43    /// Create a thread and run it in one request.
44    #[crate::byot(T0 = serde::Serialize, R = serde::de::DeserializeOwned)]
45    pub async fn create_and_run(
46        &self,
47        request: CreateThreadAndRunRequest,
48    ) -> Result<RunObject, OpenAIError> {
49        self.client.post("/threads/runs", request).await
50    }
51
52    /// Create a thread and run it in one request (streaming).
53    ///
54    /// byot: You must ensure "stream: true" in serialized `request`
55    #[crate::byot(
56        T0 = serde::Serialize,
57        R = serde::de::DeserializeOwned,
58        stream = "true",
59        where_clause = "R: std::marker::Send + 'static + TryFrom<eventsource_stream::Event, Error = OpenAIError>"
60    )]
61    #[allow(unused_mut)]
62    pub async fn create_and_run_stream(
63        &self,
64        mut request: CreateThreadAndRunRequest,
65    ) -> Result<AssistantEventStream, OpenAIError> {
66        #[cfg(not(feature = "byot"))]
67        {
68            if request.stream.is_some() && !request.stream.unwrap() {
69                return Err(OpenAIError::InvalidArgument(
70                    "When stream is false, use Threads::create_and_run".into(),
71                ));
72            }
73
74            request.stream = Some(true);
75        }
76        Ok(self
77            .client
78            .post_stream_mapped_raw_events("/threads/runs", request, TryFrom::try_from)
79            .await)
80    }
81
82    /// Create a thread.
83    #[crate::byot(T0 = serde::Serialize, R = serde::de::DeserializeOwned)]
84    pub async fn create(&self, request: CreateThreadRequest) -> Result<ThreadObject, OpenAIError> {
85        self.client.post("/threads", request).await
86    }
87
88    /// Retrieves a thread.
89    #[crate::byot(T0 = std::fmt::Display, R = serde::de::DeserializeOwned)]
90    pub async fn retrieve(&self, thread_id: &str) -> Result<ThreadObject, OpenAIError> {
91        self.client.get(&format!("/threads/{thread_id}")).await
92    }
93
94    /// Modifies a thread.
95    #[crate::byot(T0 = std::fmt::Display, T1 = serde::Serialize, R = serde::de::DeserializeOwned)]
96    pub async fn update(
97        &self,
98        thread_id: &str,
99        request: ModifyThreadRequest,
100    ) -> Result<ThreadObject, OpenAIError> {
101        self.client
102            .post(&format!("/threads/{thread_id}"), request)
103            .await
104    }
105
106    /// Delete a thread.
107    #[crate::byot(T0 = std::fmt::Display, R = serde::de::DeserializeOwned)]
108    pub async fn delete(&self, thread_id: &str) -> Result<DeleteThreadResponse, OpenAIError> {
109        self.client.delete(&format!("/threads/{thread_id}")).await
110    }
111}