openai_api/
lib.rs

1///! `OpenAI` API client library
2#[macro_use]
3extern crate derive_builder;
4
5use thiserror::Error;
6
7type Result<T> = std::result::Result<T, Error>;
8
9#[allow(clippy::default_trait_access)]
10pub mod api {
11    //! Data types corresponding to requests and responses from the API
12    use std::{collections::HashMap, fmt::Display};
13
14    use super::Client;
15    use serde::{Deserialize, Serialize};
16
17    /// Container type. Used in the api, but not useful for clients of this library
18    #[derive(Deserialize, Debug)]
19    pub(super) struct Container<T> {
20        pub data: Vec<T>,
21    }
22
23    /// Engine description type
24    #[derive(Deserialize, Debug, Eq, PartialEq, Clone)]
25    pub struct EngineInfo {
26        pub id: Engine,
27        pub owner: String,
28        pub ready: bool,
29    }
30
31    /// Engine types, known and unknown
32    #[derive(Deserialize, Serialize, Debug, Ord, PartialOrd, Eq, PartialEq, Copy, Clone)]
33    #[serde(rename_all = "kebab-case")]
34    #[non_exhaustive] // prevent clients from matching on every option
35    pub enum Engine {
36        Ada,
37        Babbage,
38        Curie,
39        Davinci,
40        #[serde(rename = "content-filter-alpha-c4")]
41        ContentFilter,
42        #[serde(other)]
43        Other,
44    }
45
46    // Custom Display to lowercase things
47    impl std::fmt::Display for Engine {
48        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
49            match self {
50                Engine::Ada => f.write_str("ada"),
51                Engine::Babbage => f.write_str("babbage"),
52                Engine::Curie => f.write_str("curie"),
53                Engine::Davinci => f.write_str("davinci"),
54                Engine::ContentFilter => f.write_str("content-filter-alpha-c4"),
55                _ => panic!("Can't write out Other engine id"),
56            }
57        }
58    }
59
60    /// Options that affect the result
61    #[derive(Serialize, Debug, Builder, Clone)]
62    pub struct CompletionArgs {
63        #[builder(setter(into), default = "\"<|endoftext|>\".into()")]
64        prompt: String,
65        #[builder(default = "Engine::Davinci")]
66        #[serde(skip_serializing)]
67        pub(super) engine: Engine,
68        #[builder(default = "16")]
69        max_tokens: u64,
70        #[builder(default = "1.0")]
71        temperature: f64,
72        #[builder(default = "1.0")]
73        top_p: f64,
74        #[builder(default = "1")]
75        n: u64,
76        #[builder(setter(strip_option), default)]
77        logprobs: Option<u64>,
78        #[builder(default = "false")]
79        echo: bool,
80        #[builder(setter(strip_option), default)]
81        stop: Option<Vec<String>>,
82        #[builder(default = "0.0")]
83        presence_penalty: f64,
84        #[builder(default = "0.0")]
85        frequency_penalty: f64,
86        #[builder(default)]
87        logit_bias: HashMap<String, f64>,
88    }
89
90    /* {
91          "stream": false, // SSE streams back results
92          "best_of": Option<u64>, //cant be used with stream
93       }
94    */
95    // TODO: add validators for the different arguments
96
97    impl From<&str> for CompletionArgs {
98        fn from(prompt_string: &str) -> Self {
99            Self {
100                prompt: prompt_string.into(),
101                ..CompletionArgsBuilder::default()
102                    .build()
103                    .expect("default should build")
104            }
105        }
106    }
107
108    impl CompletionArgs {
109        /// Build a `CompletionArgs` from the defaults
110        #[must_use]
111        pub fn builder() -> CompletionArgsBuilder {
112            CompletionArgsBuilder::default()
113        }
114
115        /// Request a completion from the api
116        ///
117        /// # Errors
118        /// `Error::APIError` if the api returns an error
119        #[cfg(feature = "async")]
120        pub async fn complete_prompt(self, client: &Client) -> super::Result<Completion> {
121            client.complete_prompt(self).await
122        }
123
124        #[cfg(feature = "sync")]
125        pub fn complete_prompt_sync(self, client: &Client) -> super::Result<Completion> {
126            client.complete_prompt_sync(self)
127        }
128    }
129
130    impl CompletionArgsBuilder {
131        /// Request a completion from the api
132        ///
133        /// # Errors
134        /// `Error::BadArguments` if the arguments to complete are not valid
135        /// `Error::APIError` if the api returns an error
136        #[cfg(feature = "async")]
137        pub async fn complete_prompt(&self, client: &Client) -> super::Result<Completion> {
138            client.complete_prompt(self.build()?).await
139        }
140
141        #[cfg(feature = "sync")]
142        pub fn complete_prompt_sync(&self, client: &Client) -> super::Result<Completion> {
143            client.complete_prompt_sync(self.build()?)
144        }
145    }
146
147    /// Represents a non-streamed completion response
148    #[derive(Deserialize, Debug, Clone)]
149    pub struct Completion {
150        /// Completion unique identifier
151        pub id: String,
152        /// Unix timestamp when the completion was generated
153        pub created: u64,
154        /// Exact model type and version used for the completion
155        pub model: String,
156        /// Timestamp
157        pub choices: Vec<Choice>,
158    }
159
160    impl std::fmt::Display for Completion {
161        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
162            write!(f, "{}", self.choices[0])
163        }
164    }
165
166    /// A single completion result
167    #[derive(Deserialize, Debug, Clone)]
168    pub struct Choice {
169        /// The text of the completion. Will contain the prompt if echo is True.
170        pub text: String,
171        /// Offset in the result where the completion began. Useful if using echo.
172        pub index: u64,
173        /// If requested, the log probabilities of the completion tokens
174        pub logprobs: Option<LogProbs>,
175        /// Why the completion ended when it did
176        pub finish_reason: FinishReason,
177    }
178
179    impl std::fmt::Display for Choice {
180        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
181            self.text.fmt(f)
182        }
183    }
184
185    /// Represents a logprobs subdocument
186    #[derive(Deserialize, Debug, Clone)]
187    pub struct LogProbs {
188        pub tokens: Vec<String>,
189        pub token_logprobs: Vec<Option<f64>>,
190        pub top_logprobs: Vec<Option<HashMap<String, f64>>>,
191        pub text_offset: Vec<u64>,
192    }
193
194    /// Reason a prompt completion finished.
195    #[derive(Deserialize, Debug, Eq, PartialEq, Clone, Copy)]
196    #[non_exhaustive]
197    pub enum FinishReason {
198        /// The maximum length was reached
199        #[serde(rename = "length")]
200        MaxTokensReached,
201        /// The stop token was encountered
202        #[serde(rename = "stop")]
203        StopSequenceReached,
204    }
205
206    /// Error response object from the server
207    #[derive(Deserialize, Debug, Eq, PartialEq, Clone)]
208    pub struct ErrorMessage {
209        pub message: String,
210        #[serde(rename = "type")]
211        pub error_type: String,
212    }
213
214    impl Display for ErrorMessage {
215        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
216            self.message.fmt(f)
217        }
218    }
219
220    /// API-level wrapper used in deserialization
221    #[derive(Deserialize, Debug)]
222    pub(crate) struct ErrorWrapper {
223        pub error: ErrorMessage,
224    }
225}
226
227/// This library's main `Error` type.
228#[derive(Error, Debug)]
229pub enum Error {
230    /// An error returned by the API itself
231    #[error("API returned an Error: {}", .0.message)]
232    APIError(api::ErrorMessage),
233    /// An error the client discovers before talking to the API
234    #[error("Bad arguments: {0}")]
235    BadArguments(String),
236    /// Network / protocol related errors
237    #[cfg(feature = "async")]
238    #[error("Error at the protocol level: {0}")]
239    AsyncProtocolError(surf::Error),
240    #[cfg(feature = "sync")]
241    #[error("Error at the protocol level, sync client")]
242    SyncProtocolError(ureq::Error),
243}
244
245impl From<api::ErrorMessage> for Error {
246    fn from(e: api::ErrorMessage) -> Self {
247        Error::APIError(e)
248    }
249}
250
251impl From<String> for Error {
252    fn from(e: String) -> Self {
253        Error::BadArguments(e)
254    }
255}
256
257#[cfg(feature = "async")]
258impl From<surf::Error> for Error {
259    fn from(e: surf::Error) -> Self {
260        Error::AsyncProtocolError(e)
261    }
262}
263
264#[cfg(feature = "sync")]
265impl From<ureq::Error> for Error {
266    fn from(e: ureq::Error) -> Self {
267        Error::SyncProtocolError(e)
268    }
269}
270
/// Authentication middleware
///
/// Attaches the API token to every request made by the async client.
#[cfg(feature = "async")]
struct BearerToken {
    // The raw API token; only a prefix of it is shown in this type's `Debug` output.
    token: String,
}
276
#[cfg(feature = "async")]
impl std::fmt::Debug for BearerToken {
    /// Shows only the first 8 characters of the token, so debugging output does not
    /// accidentally log the key.
    ///
    /// Tokens of 8 characters or fewer are hidden completely: their "prefix" would be the
    /// whole secret. Slicing on a char boundary found via `char_indices` can never fail,
    /// so this impl never returns `fmt::Error` (which would make `format!` panic).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let prefix = match self.token.char_indices().nth(8) {
            Some((end, _)) => &self.token[..end],
            None => "",
        };
        write!(f, r#"Bearer {{ token: "{}" }}"#, prefix)
    }
}
288
#[cfg(feature = "async")]
impl BearerToken {
    /// Stores an owned copy of `token`, later sent in the `Authorization` header.
    fn new(token: &str) -> Self {
        let token = token.to_owned();
        Self { token }
    }
}
297
#[cfg(feature = "async")]
#[surf::utils::async_trait]
impl surf::middleware::Middleware for BearerToken {
    /// Adds `Authorization: Bearer <token>` to the outgoing request, forwards it down
    /// the middleware chain, and logs request and response at debug level.
    async fn handle(
        &self,
        mut req: surf::Request,
        client: surf::Client,
        next: surf::middleware::Next<'_>,
    ) -> surf::Result<surf::Response> {
        // Logged before the header is attached, so the token is not part of this line.
        log::debug!("Request: {:?}", req);
        let auth_value = format!("Bearer {}", self.token);
        req.insert_header("Authorization", auth_value);
        let response = next.run(req, client).await?;
        log::debug!("Response: {:?}", response);
        Ok(response)
    }
}
314
/// Builds the async HTTP client: relative request paths resolve against `base_url`,
/// and every request passes through the bearer-token middleware.
///
/// # Panics
/// Panics if `base_url` is not a valid URL.
#[cfg(feature = "async")]
fn async_client(token: &str, base_url: &str) -> surf::Client {
    let base = surf::Url::parse(base_url).expect("Static string should parse");
    let mut client = surf::client();
    client.set_base_url(base);
    client.with(BearerToken::new(token))
}
321
/// Builds the blocking HTTP agent, configured via `auth_kind` to send `token` as a
/// `Bearer` credential with its requests.
#[cfg(feature = "sync")]
fn sync_client(token: &str) -> ureq::Agent {
    ureq::agent().auth_kind("Bearer", token).build()
}
326
/// Client object. Must be constructed to talk to the API.
///
/// Which transports it holds depends on the enabled cargo features: `async` adds a
/// `surf` client, `sync` adds a `ureq` agent plus the base URL its requests use.
// NOTE(review): the derived `Debug` includes the inner `ureq::Agent`; confirm its
// `Debug` output does not expose the Authorization header.
#[derive(Debug, Clone)]
pub struct Client {
    // Async HTTP client, preconfigured with the base URL and bearer-token middleware.
    #[cfg(feature = "async")]
    async_client: surf::Client,
    // Blocking HTTP agent, preconfigured with the bearer token.
    #[cfg(feature = "sync")]
    sync_client: ureq::Agent,
    // API root; the sync methods prepend it to every endpoint path.
    #[cfg(feature = "sync")]
    base_url: String,
}
337
338impl Client {
339    // Creates a new `Client` given an api token
340    #[must_use]
341    pub fn new(token: &str) -> Self {
342        let base_url = String::from("https://api.openai.com/v1/");
343        Self {
344            #[cfg(feature = "async")]
345            async_client: async_client(token, &base_url),
346            #[cfg(feature = "sync")]
347            sync_client: sync_client(token),
348            #[cfg(feature = "sync")]
349            base_url,
350        }
351    }
352
353    // Allow setting the api root in the tests
354    #[cfg(test)]
355    fn set_api_root(mut self, base_url: &str) -> Self {
356        #[cfg(feature = "async")]
357        {
358            self.async_client.set_base_url(
359                surf::Url::parse(base_url).expect("static URL expected to parse correctly"),
360            );
361        }
362        #[cfg(feature = "sync")]
363        {
364            self.base_url = String::from(base_url);
365        }
366        self
367    }
368
369    /// Private helper for making gets
370    #[cfg(feature = "async")]
371    async fn get<T>(&self, endpoint: &str) -> Result<T>
372    where
373        T: serde::de::DeserializeOwned,
374    {
375        let mut response = self.async_client.get(endpoint).await?;
376        if let surf::StatusCode::Ok = response.status() {
377            Ok(response.body_json::<T>().await?)
378        } else {
379            let err = response.body_json::<api::ErrorWrapper>().await?.error;
380            Err(Error::APIError(err))
381        }
382    }
383
384    #[cfg(feature = "sync")]
385    fn get_sync<T>(&self, endpoint: &str) -> Result<T>
386    where
387        T: serde::de::DeserializeOwned,
388    {
389        let response = dbg!(self
390            .sync_client
391            .get(&format!("{}{}", self.base_url, endpoint)))
392        .call();
393        if let 200 = response.status() {
394            Ok(response
395                .into_json_deserialize()
396                .expect("Bug: client couldn't deserialize api response"))
397        } else {
398            let err = response
399                .into_json_deserialize::<api::ErrorWrapper>()
400                .expect("Bug: client couldn't deserialize api error response")
401                .error;
402            Err(Error::APIError(err))
403        }
404    }
405
406    /// Lists the currently available engines.
407    ///
408    /// Provides basic information about each one such as the owner and availability.
409    ///
410    /// # Errors
411    /// - `Error::APIError` if the server returns an error
412    #[cfg(feature = "async")]
413    pub async fn engines(&self) -> Result<Vec<api::EngineInfo>> {
414        self.get("engines").await.map(|r: api::Container<_>| r.data)
415    }
416
417    /// Lists the currently available engines.
418    ///
419    /// Provides basic information about each one such as the owner and availability.
420    ///
421    /// # Errors
422    /// - `Error::APIError` if the server returns an error
423    #[cfg(feature = "sync")]
424    pub fn engines_sync(&self) -> Result<Vec<api::EngineInfo>> {
425        self.get_sync("engines").map(|r: api::Container<_>| r.data)
426    }
427
428    /// Retrieves an engine instance
429    ///
430    /// Provides basic information about the engine such as the owner and availability.
431    ///
432    /// # Errors
433    /// - `Error::APIError` if the server returns an error
434    #[cfg(feature = "async")]
435    pub async fn engine(&self, engine: api::Engine) -> Result<api::EngineInfo> {
436        self.get(&format!("engines/{}", engine)).await
437    }
438
439    #[cfg(feature = "sync")]
440    pub fn engine_sync(&self, engine: api::Engine) -> Result<api::EngineInfo> {
441        self.get_sync(&format!("engines/{}", engine))
442    }
443
444    // Private helper to generate post requests. Needs to be a bit more flexible than
445    // get because it should support SSE eventually
446    #[cfg(feature = "async")]
447    async fn post<B, R>(&self, endpoint: &str, body: B) -> Result<R>
448    where
449        B: serde::ser::Serialize,
450        R: serde::de::DeserializeOwned,
451    {
452        let mut response = self
453            .async_client
454            .post(endpoint)
455            .body(surf::Body::from_json(&body)?)
456            .await?;
457        match response.status() {
458            surf::StatusCode::Ok => Ok(response.body_json::<R>().await?),
459            _ => Err(Error::APIError(
460                response
461                    .body_json::<api::ErrorWrapper>()
462                    .await
463                    .expect("The API has returned something funky")
464                    .error,
465            )),
466        }
467    }
468
469    #[cfg(feature = "sync")]
470    fn post_sync<B, R>(&self, endpoint: &str, body: B) -> Result<R>
471    where
472        B: serde::ser::Serialize,
473        R: serde::de::DeserializeOwned,
474    {
475        let response = self
476            .sync_client
477            .post(&format!("{}{}", self.base_url, endpoint))
478            .send_json(
479                serde_json::to_value(body).expect("Bug: client couldn't serialize its own type"),
480            );
481        match response.status() {
482            200 => Ok(response
483                .into_json_deserialize()
484                .expect("Bug: client couldn't deserialize api response")),
485            _ => Err(Error::APIError(
486                response
487                    .into_json_deserialize::<api::ErrorWrapper>()
488                    .expect("Bug: client couldn't deserialize api error response")
489                    .error,
490            )),
491        }
492    }
493
494    /// Get predicted completion of the prompt
495    ///
496    /// # Errors
497    ///  - `Error::APIError` if the api returns an error
498    #[cfg(feature = "async")]
499    pub async fn complete_prompt(
500        &self,
501        prompt: impl Into<api::CompletionArgs>,
502    ) -> Result<api::Completion> {
503        let args = prompt.into();
504        Ok(self
505            .post(&format!("engines/{}/completions", args.engine), args)
506            .await?)
507    }
508
509    /// Get predicted completion of the prompt synchronously
510    ///
511    /// # Error
512    /// - `Error::APIError` if the api returns an error
513    #[cfg(feature = "sync")]
514    pub fn complete_prompt_sync(
515        &self,
516        prompt: impl Into<api::CompletionArgs>,
517    ) -> Result<api::Completion> {
518        let args = prompt.into();
519        self.post_sync(&format!("engines/{}/completions", args.engine), args)
520    }
521}
522
// TODO: generate both the sync and the async variant of each test from a single macro invocation to remove the remaining duplication
524
/// Defines a `#[tokio::test]`, compiled only with the `async` feature, that runs the
/// given block and then returns `Ok(())`, so `?` can be used inside the block.
// Only the `#[cfg(test)]` modules invoke this macro, hence the allow.
#[allow(unused_macros)]
macro_rules! async_test {
    ($name:ident, $body:block) => {
        #[cfg(feature = "async")]
        #[tokio::test]
        async fn $name() -> crate::Result<()> {
            $body;
            Ok(())
        }
    };
}
536
/// Defines a `#[test]`, compiled only with the `sync` feature, that runs the given
/// block and then returns `Ok(())`, so `?` can be used inside the block.
///
/// Takes a `block` (not an arbitrary `expr`), matching `async_test!`.
// Only the `#[cfg(test)]` modules invoke this macro, hence the allow.
#[allow(unused_macros)]
macro_rules! sync_test {
    ($test_name: ident, $test_body: block) => {
        #[cfg(feature = "sync")]
        #[test]
        fn $test_name() -> crate::Result<()> {
            $test_body;
            Ok(())
        }
    };
}
548
#[cfg(test)]
mod unit {
    //! Offline tests: requests go to a local `mockito` server instead of the real API.

    use mockito::Mock;

    use crate::{
        api::{self, Completion, CompletionArgs, Engine, EngineInfo},
        Client, Error,
    };

    /// Client pointed at the mockito server, with test logging initialised once.
    fn mocked_client() -> Client {
        let _ = env_logger::builder().is_test(true).try_init();
        Client::new("bogus").set_api_root(&format!("{}/", mockito::server_url()))
    }

    #[test]
    fn can_create_client() {
        let _c = mocked_client();
    }

    #[test]
    fn parse_engine_info() -> Result<(), Box<dyn std::error::Error>> {
        let example = r#"{
            "id": "ada",
            "object": "engine",
            "owner": "openai",
            "ready": true
        }"#;
        let ei: api::EngineInfo = serde_json::from_str(example)?;
        assert_eq!(
            ei,
            api::EngineInfo {
                id: api::Engine::Ada,
                owner: "openai".into(),
                ready: true,
            }
        );
        Ok(())
    }

    /// Mocks `GET /engines` (including one unknown engine id) and returns the
    /// `EngineInfo`s the client is expected to parse from it.
    fn mock_engines() -> (Mock, Vec<EngineInfo>) {
        let mock = mockito::mock("GET", "/engines")
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(
                r#"{
            "object": "list",
            "data": [
              {
                "id": "ada",
                "object": "engine",
                "owner": "openai",
                "ready": true
              },
              {
                "id": "babbage",
                "object": "engine",
                "owner": "openai",
                "ready": true
              },
              {
                "id": "experimental-engine-v7",
                "object": "engine",
                "owner": "openai",
                "ready": false
              },
              {
                "id": "curie",
                "object": "engine",
                "owner": "openai",
                "ready": true
              },
              {
                "id": "davinci",
                "object": "engine",
                "owner": "openai",
                "ready": true
              },
              {
                 "id": "content-filter-alpha-c4",
                 "object": "engine",
                 "owner": "openai",
                 "ready": true
              }
            ]
          }"#,
            )
            .create();

        let expected = vec![
            EngineInfo {
                id: Engine::Ada,
                owner: "openai".into(),
                ready: true,
            },
            EngineInfo {
                id: Engine::Babbage,
                owner: "openai".into(),
                ready: true,
            },
            EngineInfo {
                id: Engine::Other,
                owner: "openai".into(),
                ready: false,
            },
            EngineInfo {
                id: Engine::Curie,
                owner: "openai".into(),
                ready: true,
            },
            EngineInfo {
                id: Engine::Davinci,
                owner: "openai".into(),
                ready: true,
            },
            EngineInfo {
                id: Engine::ContentFilter,
                owner: "openai".into(),
                ready: true,
            },
        ];
        (mock, expected)
    }

    async_test!(parse_engines_async, {
        let (_m, expected) = mock_engines();
        let response = mocked_client().engines().await?;
        assert_eq!(response, expected);
    });

    sync_test!(parse_engines_sync, {
        let (_m, expected) = mock_engines();
        let response = mocked_client().engines_sync()?;
        assert_eq!(response, expected);
    });

    /// Mocks `GET /engines/davinci` answering 404 with an API error object.
    fn mock_engine() -> (Mock, api::ErrorMessage) {
        let mock = mockito::mock("GET", "/engines/davinci")
            .with_status(404)
            .with_header("content-type", "application/json")
            .with_body(
                r#"{
                "error": {
                    "code": null,
                    "message": "Some kind of error happened",
                    "type": "some_error_type"
                }
            }"#,
            )
            .create();
        let expected = api::ErrorMessage {
            message: "Some kind of error happened".into(),
            error_type: "some_error_type".into(),
        };
        (mock, expected)
    }

    // These tests must fail when the call succeeds or fails in some other way, so
    // they `match` exhaustively instead of using `if let`.
    async_test!(engine_error_response_async, {
        let (_m, expected) = mock_engine();
        let response = mocked_client().engine(api::Engine::Davinci).await;
        match response {
            Err(Error::APIError(msg)) => assert_eq!(expected, msg),
            other => panic!("Expected an API error, got {:?}", other),
        }
    });

    sync_test!(engine_error_response_sync, {
        let (_m, expected) = mock_engine();
        let response = mocked_client().engine_sync(api::Engine::Davinci);
        match response {
            Err(Error::APIError(msg)) => assert_eq!(expected, msg),
            other => panic!("Expected an API error, got {:?}", other),
        }
    });

    /// Mocks a single successful completion call and returns the mock, the request
    /// arguments and the expected parsed response.
    fn mock_completion() -> crate::Result<(Mock, CompletionArgs, Completion)> {
        let mock = mockito::mock("POST", "/engines/davinci/completions")
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(
                r#"{
                "id": "cmpl-uqkvlQyYK7bGYrRHQ0eXlWi7",
                "object": "text_completion",
                "created": 1589478378,
                "model": "davinci:2020-05-03",
                "choices": [
                    {
                    "text": " there was a girl who",
                    "index": 0,
                    "logprobs": null,
                    "finish_reason": "length"
                    }
                ]
                }"#,
            )
            .expect(1)
            .create();
        let args = api::CompletionArgs::builder()
            .engine(api::Engine::Davinci)
            .prompt("Once upon a time")
            .max_tokens(5)
            .temperature(1.0)
            .top_p(1.0)
            .n(1)
            .stop(vec!["\n".into()])
            .build()?;
        let expected = api::Completion {
            id: "cmpl-uqkvlQyYK7bGYrRHQ0eXlWi7".into(),
            created: 1589478378,
            model: "davinci:2020-05-03".into(),
            choices: vec![api::Choice {
                text: " there was a girl who".into(),
                index: 0,
                logprobs: None,
                finish_reason: api::FinishReason::MaxTokensReached,
            }],
        };
        Ok((mock, args, expected))
    }

    // `Completion` doesn't implement `PartialEq` (its logprobs contain floats), so the
    // comparison is done field by field. Every choice is compared, not just the first;
    // logprobs are only compared for presence.
    fn assert_completion_equal(a: Completion, b: Completion) {
        assert_eq!(a.model, b.model);
        assert_eq!(a.id, b.id);
        assert_eq!(a.created, b.created);
        assert_eq!(a.choices.len(), b.choices.len());
        for (a_choice, b_choice) in a.choices.iter().zip(&b.choices) {
            assert_eq!(a_choice.text, b_choice.text);
            assert_eq!(a_choice.index, b_choice.index);
            assert_eq!(a_choice.logprobs.is_none(), b_choice.logprobs.is_none());
            assert_eq!(a_choice.finish_reason, b_choice.finish_reason);
        }
    }

    async_test!(completion_args_async, {
        let (m, args, expected) = mock_completion()?;
        let response = mocked_client().complete_prompt(args).await?;
        assert_completion_equal(response, expected);
        m.assert();
    });

    sync_test!(completion_args_sync, {
        let (m, args, expected) = mock_completion()?;
        let response = mocked_client().complete_prompt_sync(args)?;
        assert_completion_equal(response, expected);
        m.assert();
    });
}
#[cfg(test)]
mod integration {
    //! Tests against the live API. They need network access and the `OPENAI_SK`
    //! environment variable set to a valid API token.
    use crate::{
        api::{self, Completion},
        Client, Error,
    };

    /// Used by tests to get a client to the actual api
    ///
    /// # Panics
    /// Panics if the `OPENAI_SK` environment variable is not set.
    fn get_client() -> Client {
        let _ = env_logger::builder().is_test(true).try_init();
        let sk = std::env::var("OPENAI_SK").expect(
            "To run integration tests, you must set the OPENAI_SK env var to your api token",
        );
        Client::new(&sk)
    }

    /// Asserts that the listing contains the four general-purpose engines; shared by
    /// the sync and async tests so both check the same thing.
    fn assert_base_engines_listed(engines: &[api::EngineInfo]) {
        let ids: Vec<api::Engine> = engines.iter().map(|ei| ei.id).collect();
        for engine in &[
            api::Engine::Ada,
            api::Engine::Babbage,
            api::Engine::Curie,
            api::Engine::Davinci,
        ] {
            assert!(ids.contains(engine), "{:?} missing from {:?}", engine, ids);
        }
    }

    async_test!(can_get_engines_async, {
        let client = get_client();
        assert_base_engines_listed(&client.engines().await?);
    });

    sync_test!(can_get_engines_sync, {
        let client = get_client();
        assert_base_engines_listed(&client.engines_sync()?);
    });

    /// Checks the error the API returns when asked for the `ada` engine.
    // NOTE(review): despite the `can_get_engine_*` test names, this expects the request
    // to be rejected; confirm that still matches the live API's behaviour.
    fn assert_expected_engine_failure<T>(result: Result<T, Error>)
    where
        T: std::fmt::Debug,
    {
        match result {
            Err(Error::APIError(api::ErrorMessage {
                message,
                error_type,
            })) => {
                assert_eq!(message, "No engine with that ID: ada");
                assert_eq!(error_type, "invalid_request_error");
            }
            _ => {
                panic!("Expected an error message, got {:?}", result)
            }
        }
    }

    async_test!(can_get_engine_async, {
        let client = get_client();
        assert_expected_engine_failure(client.engine(api::Engine::Ada).await);
    });

    sync_test!(can_get_engine_sync, {
        let client = get_client();
        assert_expected_engine_failure(client.engine_sync(api::Engine::Ada));
    });

    async_test!(complete_string_async, {
        let client = get_client();
        client.complete_prompt("Hey there").await?;
    });

    sync_test!(complete_string_sync, {
        let client = get_client();
        client.complete_prompt_sync("Hey there")?;
    });

    /// Arguments that set every supported completion parameter explicitly.
    fn create_args() -> api::CompletionArgs {
        api::CompletionArgsBuilder::default()
            .prompt("Once upon a time,")
            .max_tokens(10)
            .temperature(0.5)
            .top_p(0.5)
            .n(1)
            .logprobs(3)
            .echo(false)
            .stop(vec!["\n".into()])
            .presence_penalty(0.5)
            .frequency_penalty(0.5)
            .logit_bias(maplit::hashmap! {
                "1".into() => 1.0,
                "23".into() => 0.0,
            })
            .build()
            .expect("Bug: build should succeed")
    }

    async_test!(complete_explicit_params_async, {
        let client = get_client();
        let args = create_args();
        client.complete_prompt(args).await?;
    });

    sync_test!(complete_explicit_params_sync, {
        let client = get_client();
        let args = create_args();
        client.complete_prompt_sync(args)?;
    });

    /// Arguments for a prompt whose answer should hit one of the stop sequences.
    fn stop_condition_args() -> api::CompletionArgs {
        let mut args = api::CompletionArgs::builder();
        args.prompt(
            r#"
Q: Please type `#` now
A:"#,
        )
        // turn temp & top_p way down to prevent test flakiness
        .temperature(0.0)
        .top_p(0.0)
        .max_tokens(100)
        .stop(vec!["#".into(), "\n".into()])
        .build()
        .expect("Bug: build should succeed")
    }

    /// Asserts the first choice ended because a stop sequence was produced.
    fn assert_completion_finish_reason(completion: Completion) {
        assert_eq!(
            completion.choices[0].finish_reason,
            api::FinishReason::StopSequenceReached
        );
    }

    async_test!(complete_stop_condition_async, {
        let client = get_client();
        let args = stop_condition_args();
        assert_completion_finish_reason(client.complete_prompt(args).await?);
    });

    sync_test!(complete_stop_condition_sync, {
        let client = get_client();
        let args = stop_condition_args();
        assert_completion_finish_reason(client.complete_prompt_sync(args)?);
    });
}