1#[macro_use]
3extern crate derive_builder;
4
5use thiserror::Error;
6
7type Result<T> = std::result::Result<T, Error>;
8
9#[allow(clippy::default_trait_access)]
10pub mod api {
11 use std::{collections::HashMap, fmt::Display};
13
14 use super::Client;
15 use serde::{Deserialize, Serialize};
16
17 #[derive(Deserialize, Debug)]
19 pub(super) struct Container<T> {
20 pub data: Vec<T>,
21 }
22
23 #[derive(Deserialize, Debug, Eq, PartialEq, Clone)]
25 pub struct EngineInfo {
26 pub id: Engine,
27 pub owner: String,
28 pub ready: bool,
29 }
30
31 #[derive(Deserialize, Serialize, Debug, Ord, PartialOrd, Eq, PartialEq, Copy, Clone)]
33 #[serde(rename_all = "kebab-case")]
34 #[non_exhaustive] pub enum Engine {
36 Ada,
37 Babbage,
38 Curie,
39 Davinci,
40 #[serde(rename = "content-filter-alpha-c4")]
41 ContentFilter,
42 #[serde(other)]
43 Other,
44 }
45
46 impl std::fmt::Display for Engine {
48 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
49 match self {
50 Engine::Ada => f.write_str("ada"),
51 Engine::Babbage => f.write_str("babbage"),
52 Engine::Curie => f.write_str("curie"),
53 Engine::Davinci => f.write_str("davinci"),
54 Engine::ContentFilter => f.write_str("content-filter-alpha-c4"),
55 _ => panic!("Can't write out Other engine id"),
56 }
57 }
58 }
59
60 #[derive(Serialize, Debug, Builder, Clone)]
62 pub struct CompletionArgs {
63 #[builder(setter(into), default = "\"<|endoftext|>\".into()")]
64 prompt: String,
65 #[builder(default = "Engine::Davinci")]
66 #[serde(skip_serializing)]
67 pub(super) engine: Engine,
68 #[builder(default = "16")]
69 max_tokens: u64,
70 #[builder(default = "1.0")]
71 temperature: f64,
72 #[builder(default = "1.0")]
73 top_p: f64,
74 #[builder(default = "1")]
75 n: u64,
76 #[builder(setter(strip_option), default)]
77 logprobs: Option<u64>,
78 #[builder(default = "false")]
79 echo: bool,
80 #[builder(setter(strip_option), default)]
81 stop: Option<Vec<String>>,
82 #[builder(default = "0.0")]
83 presence_penalty: f64,
84 #[builder(default = "0.0")]
85 frequency_penalty: f64,
86 #[builder(default)]
87 logit_bias: HashMap<String, f64>,
88 }
89
90 impl From<&str> for CompletionArgs {
98 fn from(prompt_string: &str) -> Self {
99 Self {
100 prompt: prompt_string.into(),
101 ..CompletionArgsBuilder::default()
102 .build()
103 .expect("default should build")
104 }
105 }
106 }
107
108 impl CompletionArgs {
109 #[must_use]
111 pub fn builder() -> CompletionArgsBuilder {
112 CompletionArgsBuilder::default()
113 }
114
115 #[cfg(feature = "async")]
120 pub async fn complete_prompt(self, client: &Client) -> super::Result<Completion> {
121 client.complete_prompt(self).await
122 }
123
124 #[cfg(feature = "sync")]
125 pub fn complete_prompt_sync(self, client: &Client) -> super::Result<Completion> {
126 client.complete_prompt_sync(self)
127 }
128 }
129
130 impl CompletionArgsBuilder {
131 #[cfg(feature = "async")]
137 pub async fn complete_prompt(&self, client: &Client) -> super::Result<Completion> {
138 client.complete_prompt(self.build()?).await
139 }
140
141 #[cfg(feature = "sync")]
142 pub fn complete_prompt_sync(&self, client: &Client) -> super::Result<Completion> {
143 client.complete_prompt_sync(self.build()?)
144 }
145 }
146
147 #[derive(Deserialize, Debug, Clone)]
149 pub struct Completion {
150 pub id: String,
152 pub created: u64,
154 pub model: String,
156 pub choices: Vec<Choice>,
158 }
159
160 impl std::fmt::Display for Completion {
161 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
162 write!(f, "{}", self.choices[0])
163 }
164 }
165
166 #[derive(Deserialize, Debug, Clone)]
168 pub struct Choice {
169 pub text: String,
171 pub index: u64,
173 pub logprobs: Option<LogProbs>,
175 pub finish_reason: FinishReason,
177 }
178
179 impl std::fmt::Display for Choice {
180 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
181 self.text.fmt(f)
182 }
183 }
184
185 #[derive(Deserialize, Debug, Clone)]
187 pub struct LogProbs {
188 pub tokens: Vec<String>,
189 pub token_logprobs: Vec<Option<f64>>,
190 pub top_logprobs: Vec<Option<HashMap<String, f64>>>,
191 pub text_offset: Vec<u64>,
192 }
193
194 #[derive(Deserialize, Debug, Eq, PartialEq, Clone, Copy)]
196 #[non_exhaustive]
197 pub enum FinishReason {
198 #[serde(rename = "length")]
200 MaxTokensReached,
201 #[serde(rename = "stop")]
203 StopSequenceReached,
204 }
205
206 #[derive(Deserialize, Debug, Eq, PartialEq, Clone)]
208 pub struct ErrorMessage {
209 pub message: String,
210 #[serde(rename = "type")]
211 pub error_type: String,
212 }
213
214 impl Display for ErrorMessage {
215 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
216 self.message.fmt(f)
217 }
218 }
219
220 #[derive(Deserialize, Debug)]
222 pub(crate) struct ErrorWrapper {
223 pub error: ErrorMessage,
224 }
225}
226
227#[derive(Error, Debug)]
229pub enum Error {
230 #[error("API returned an Error: {}", .0.message)]
232 APIError(api::ErrorMessage),
233 #[error("Bad arguments: {0}")]
235 BadArguments(String),
236 #[cfg(feature = "async")]
238 #[error("Error at the protocol level: {0}")]
239 AsyncProtocolError(surf::Error),
240 #[cfg(feature = "sync")]
241 #[error("Error at the protocol level, sync client")]
242 SyncProtocolError(ureq::Error),
243}
244
245impl From<api::ErrorMessage> for Error {
246 fn from(e: api::ErrorMessage) -> Self {
247 Error::APIError(e)
248 }
249}
250
251impl From<String> for Error {
252 fn from(e: String) -> Self {
253 Error::BadArguments(e)
254 }
255}
256
/// Wraps a failure reported by the async `surf` client.
#[cfg(feature = "async")]
impl From<surf::Error> for Error {
    fn from(source: surf::Error) -> Self {
        Self::AsyncProtocolError(source)
    }
}
263
/// Wraps a failure reported by the blocking `ureq` client.
#[cfg(feature = "sync")]
impl From<ureq::Error> for Error {
    fn from(source: ureq::Error) -> Self {
        Self::SyncProtocolError(source)
    }
}
270
/// Surf middleware state: the API token attached to every request as a bearer credential.
#[cfg(feature = "async")]
struct BearerToken {
    token: String,
}

#[cfg(feature = "async")]
impl std::fmt::Debug for BearerToken {
    /// Shows at most the first 8 bytes of the token, so the secret does not end up in logs.
    ///
    /// A token shorter than 8 bytes (or with a char boundary issue at byte 8) is fully
    /// redacted. Previously this case returned `fmt::Error`, which makes `format!` and
    /// logging macros panic.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let visible = self.token.get(0..8).unwrap_or("");
        // The trailing "..." marks the value as truncated.
        write!(f, r#"Bearer {{ token: "{}..." }}"#, visible)
    }
}
288
#[cfg(feature = "async")]
impl BearerToken {
    /// Stores an owned copy of `token`.
    fn new(token: &str) -> Self {
        let token = token.to_owned();
        Self { token }
    }
}
297
#[cfg(feature = "async")]
#[surf::utils::async_trait]
impl surf::middleware::Middleware for BearerToken {
    /// Logs the request and adds an `Authorization: Bearer <token>` header. It then
    /// forwards the request down the middleware chain and logs the response before
    /// returning it.
    async fn handle(
        &self,
        mut req: surf::Request,
        client: surf::Client,
        next: surf::middleware::Next<'_>,
    ) -> surf::Result<surf::Response> {
        // Logged before the header is inserted, so the logged request has no credentials.
        log::debug!("Request: {:?}", req);
        let credentials = format!("Bearer {}", self.token);
        req.insert_header("Authorization", credentials);

        let response = next.run(req, client).await?;
        log::debug!("Response: {:?}", response);
        Ok(response)
    }
}
314
/// Builds the surf client: `base_url` as the URL prefix plus the bearer-token middleware.
///
/// # Panics
/// Panics if `base_url` is not a valid URL.
#[cfg(feature = "async")]
fn async_client(token: &str, base_url: &str) -> surf::Client {
    let url = surf::Url::parse(base_url).expect("Static string should parse");
    let mut client = surf::client();
    client.set_base_url(url);
    client.with(BearerToken::new(token))
}
321
/// Builds the blocking ureq agent, configured with `token` as a "Bearer" credential
/// (presumably sent as the `Authorization` header on each request — see ureq's `auth_kind`).
#[cfg(feature = "sync")]
fn sync_client(token: &str) -> ureq::Agent {
    let scheme = "Bearer";
    ureq::agent().auth_kind(scheme, token).build()
}
326
/// Client for the OpenAI API.
///
/// Depending on the enabled features it holds an async (`surf`) client, a blocking
/// (`ureq`) agent, or both. Async methods are plain names; blocking ones end in `_sync`.
#[derive(Debug, Clone)]
pub struct Client {
    #[cfg(feature = "async")]
    async_client: surf::Client,
    #[cfg(feature = "sync")]
    sync_client: ureq::Agent,
    /// URL prefix for blocking requests (the surf client stores its own base URL).
    #[cfg(feature = "sync")]
    base_url: String,
}

impl Client {
    /// Creates a client that authenticates with `token` against `https://api.openai.com/v1/`.
    #[must_use]
    pub fn new(token: &str) -> Self {
        let base_url = String::from("https://api.openai.com/v1/");
        Self {
            #[cfg(feature = "async")]
            async_client: async_client(token, &base_url),
            #[cfg(feature = "sync")]
            sync_client: sync_client(token),
            #[cfg(feature = "sync")]
            base_url,
        }
    }

    /// Points every enabled client at `base_url` (tests use this to target a mock server).
    ///
    /// # Panics
    /// Panics if the async feature is enabled and `base_url` is not a valid URL.
    #[cfg(test)]
    fn set_api_root(mut self, base_url: &str) -> Self {
        #[cfg(feature = "async")]
        {
            self.async_client.set_base_url(
                surf::Url::parse(base_url).expect("static URL expected to parse correctly"),
            );
        }
        #[cfg(feature = "sync")]
        {
            self.base_url = String::from(base_url);
        }
        self
    }

    /// Turns a surf response into a value: the body as `T` on 200, otherwise an API error.
    ///
    /// # Errors
    /// [`Error::APIError`] for a non-200 status; [`Error::AsyncProtocolError`] if the body
    /// (success or error payload) fails to decode.
    #[cfg(feature = "async")]
    async fn decode_async<T>(mut response: surf::Response) -> Result<T>
    where
        T: serde::de::DeserializeOwned,
    {
        match response.status() {
            surf::StatusCode::Ok => Ok(response.body_json::<T>().await?),
            _ => {
                // Use `?` rather than panicking: an unexpected error body is a protocol
                // failure the caller can handle, not a bug in this crate.
                let err = response.body_json::<api::ErrorWrapper>().await?.error;
                Err(Error::APIError(err))
            }
        }
    }

    /// Turns a ureq response into a value: the body as `T` on 200, otherwise an API error.
    ///
    /// # Errors
    /// [`Error::APIError`] for a non-200 status; [`Error::SyncProtocolError`] if the body
    /// fails to decode. ureq 1.x reports transport failures as synthetic responses, whose
    /// bodies presumably fail to decode here. That makes them errors instead of panics.
    #[cfg(feature = "sync")]
    fn decode_sync<T>(response: ureq::Response) -> Result<T>
    where
        T: serde::de::DeserializeOwned,
    {
        let protocol_error = |e: std::io::Error| Error::SyncProtocolError(ureq::Error::Io(e));
        if response.status() == 200 {
            response.into_json_deserialize::<T>().map_err(protocol_error)
        } else {
            let wrapper = response
                .into_json_deserialize::<api::ErrorWrapper>()
                .map_err(protocol_error)?;
            Err(Error::APIError(wrapper.error))
        }
    }

    /// GETs `endpoint` (relative to the base URL) and decodes the JSON response.
    ///
    /// # Errors
    /// [`Error::AsyncProtocolError`] on transport or decoding failure,
    /// [`Error::APIError`] on a non-200 status.
    #[cfg(feature = "async")]
    async fn get<T>(&self, endpoint: &str) -> Result<T>
    where
        T: serde::de::DeserializeOwned,
    {
        let response = self.async_client.get(endpoint).await?;
        Self::decode_async(response).await
    }

    /// Blocking GET of `endpoint` (relative to the base URL), decoding the JSON response.
    ///
    /// # Errors
    /// [`Error::SyncProtocolError`] on decoding failure, [`Error::APIError`] on a
    /// non-200 status.
    #[cfg(feature = "sync")]
    fn get_sync<T>(&self, endpoint: &str) -> Result<T>
    where
        T: serde::de::DeserializeOwned,
    {
        // No `dbg!` here: a library must not print every request to stderr.
        let response = self
            .sync_client
            .get(&format!("{}{}", self.base_url, endpoint))
            .call();
        Self::decode_sync(response)
    }

    /// Lists the engines available to this token.
    ///
    /// # Errors
    /// See [`Client::get`]'s error conditions: transport/decoding or API errors.
    #[cfg(feature = "async")]
    pub async fn engines(&self) -> Result<Vec<api::EngineInfo>> {
        self.get("engines").await.map(|r: api::Container<_>| r.data)
    }

    /// Blocking version of [`Client::engines`].
    ///
    /// # Errors
    /// Decoding failures or API errors.
    #[cfg(feature = "sync")]
    pub fn engines_sync(&self) -> Result<Vec<api::EngineInfo>> {
        self.get_sync("engines").map(|r: api::Container<_>| r.data)
    }

    /// Fetches metadata for a single engine.
    ///
    /// # Errors
    /// Transport/decoding failures or API errors (e.g. an unknown engine id).
    #[cfg(feature = "async")]
    pub async fn engine(&self, engine: api::Engine) -> Result<api::EngineInfo> {
        self.get(&format!("engines/{}", engine)).await
    }

    /// Blocking version of [`Client::engine`].
    ///
    /// # Errors
    /// Decoding failures or API errors (e.g. an unknown engine id).
    #[cfg(feature = "sync")]
    pub fn engine_sync(&self, engine: api::Engine) -> Result<api::EngineInfo> {
        self.get_sync(&format!("engines/{}", engine))
    }

    /// POSTs `body` as JSON to `endpoint` and decodes the JSON response.
    ///
    /// # Errors
    /// [`Error::AsyncProtocolError`] on serialization, transport or decoding failure,
    /// [`Error::APIError`] on a non-200 status.
    #[cfg(feature = "async")]
    async fn post<B, R>(&self, endpoint: &str, body: B) -> Result<R>
    where
        B: serde::ser::Serialize,
        R: serde::de::DeserializeOwned,
    {
        let response = self
            .async_client
            .post(endpoint)
            .body(surf::Body::from_json(&body)?)
            .await?;
        Self::decode_async(response).await
    }

    /// Blocking POST of `body` as JSON to `endpoint`, decoding the JSON response.
    ///
    /// # Errors
    /// [`Error::SyncProtocolError`] on decoding failure, [`Error::APIError`] on a
    /// non-200 status.
    ///
    /// # Panics
    /// Panics if `body` cannot be converted to a JSON value (a bug in the caller's type).
    #[cfg(feature = "sync")]
    fn post_sync<B, R>(&self, endpoint: &str, body: B) -> Result<R>
    where
        B: serde::ser::Serialize,
        R: serde::de::DeserializeOwned,
    {
        let response = self
            .sync_client
            .post(&format!("{}{}", self.base_url, endpoint))
            .send_json(
                serde_json::to_value(body).expect("Bug: client couldn't serialize its own type"),
            );
        Self::decode_sync(response)
    }

    /// Requests a completion. Accepts full [`api::CompletionArgs`] or anything convertible
    /// into them, such as a `&str` prompt.
    ///
    /// # Errors
    /// Transport/decoding failures or API errors.
    #[cfg(feature = "async")]
    pub async fn complete_prompt(
        &self,
        prompt: impl Into<api::CompletionArgs>,
    ) -> Result<api::Completion> {
        let args = prompt.into();
        self.post(&format!("engines/{}/completions", args.engine), args)
            .await
    }

    /// Blocking version of [`Client::complete_prompt`].
    ///
    /// # Errors
    /// Decoding failures or API errors.
    #[cfg(feature = "sync")]
    pub fn complete_prompt_sync(
        &self,
        prompt: impl Into<api::CompletionArgs>,
    ) -> Result<api::Completion> {
        let args = prompt.into();
        self.post_sync(&format!("engines/{}/completions", args.engine), args)
    }
}
522
// Declares a tokio test that exists only with the `async` feature. The body may use `?`
// with the crate `Error`; the generated function returns `Ok(())` once the body finishes.
// `unused_macros` is allowed because the only callers are `#[cfg(test)]` modules, so
// non-test builds would otherwise warn.
#[allow(unused_macros)]
macro_rules! async_test {
    ($name:ident, $body:block) => {
        #[cfg(feature = "async")]
        #[tokio::test]
        async fn $name() -> crate::Result<()> {
            $body;
            Ok(())
        }
    };
}
536
// Declares a plain `#[test]` that exists only with the `sync` feature. The body may use `?`
// with the crate `Error`; the generated function returns `Ok(())` once the body finishes.
// `unused_macros` is allowed because the only callers are `#[cfg(test)]` modules.
#[allow(unused_macros)]
macro_rules! sync_test {
    ($name:ident, $body:expr) => {
        #[cfg(feature = "sync")]
        #[test]
        fn $name() -> crate::Result<()> {
            $body;
            Ok(())
        }
    };
}
548
/// Offline tests against a local mockito server.
#[cfg(test)]
mod unit {

    use mockito::Mock;

    use crate::{
        api::{self, Completion, CompletionArgs, Engine, EngineInfo},
        Client, Error,
    };

    /// Returns a client whose requests all go to the mockito server.
    fn mocked_client() -> Client {
        // Called by many tests; only the first logger initialisation can succeed.
        let _ = env_logger::builder().is_test(true).try_init();
        Client::new("bogus").set_api_root(&format!("{}/", mockito::server_url()))
    }

    /// Asserts that `response` is exactly the API error `expected`.
    ///
    /// Unlike an `if let`, this fails the test when the call succeeds or fails differently.
    fn assert_api_error<T>(response: Result<T, Error>, expected: &api::ErrorMessage)
    where
        T: std::fmt::Debug,
    {
        match response {
            Err(Error::APIError(msg)) => assert_eq!(*expected, msg),
            other => panic!("Expected an API error, got {:?}", other),
        }
    }

    #[test]
    fn can_create_client() {
        let _c = mocked_client();
    }

    #[test]
    fn parse_engine_info() -> Result<(), Box<dyn std::error::Error>> {
        let example = r#"{
            "id": "ada",
            "object": "engine",
            "owner": "openai",
            "ready": true
        }"#;
        let ei: api::EngineInfo = serde_json::from_str(example)?;
        assert_eq!(
            ei,
            api::EngineInfo {
                id: api::Engine::Ada,
                owner: "openai".into(),
                ready: true,
            }
        );
        Ok(())
    }

    /// Mocks `GET /engines`, including one id the crate does not know, and returns the
    /// mock together with the expected parse (the unknown id maps to `Engine::Other`).
    fn mock_engines() -> (Mock, Vec<EngineInfo>) {
        let mock = mockito::mock("GET", "/engines")
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(
                r#"{
                    "object": "list",
                    "data": [
                        {
                            "id": "ada",
                            "object": "engine",
                            "owner": "openai",
                            "ready": true
                        },
                        {
                            "id": "babbage",
                            "object": "engine",
                            "owner": "openai",
                            "ready": true
                        },
                        {
                            "id": "experimental-engine-v7",
                            "object": "engine",
                            "owner": "openai",
                            "ready": false
                        },
                        {
                            "id": "curie",
                            "object": "engine",
                            "owner": "openai",
                            "ready": true
                        },
                        {
                            "id": "davinci",
                            "object": "engine",
                            "owner": "openai",
                            "ready": true
                        },
                        {
                            "id": "content-filter-alpha-c4",
                            "object": "engine",
                            "owner": "openai",
                            "ready": true
                        }
                    ]
                }"#,
            )
            .create();

        let expected = vec![
            EngineInfo {
                id: Engine::Ada,
                owner: "openai".into(),
                ready: true,
            },
            EngineInfo {
                id: Engine::Babbage,
                owner: "openai".into(),
                ready: true,
            },
            EngineInfo {
                id: Engine::Other,
                owner: "openai".into(),
                ready: false,
            },
            EngineInfo {
                id: Engine::Curie,
                owner: "openai".into(),
                ready: true,
            },
            EngineInfo {
                id: Engine::Davinci,
                owner: "openai".into(),
                ready: true,
            },
            EngineInfo {
                id: Engine::ContentFilter,
                owner: "openai".into(),
                ready: true,
            },
        ];
        (mock, expected)
    }

    async_test!(parse_engines_async, {
        let (_m, expected) = mock_engines();
        let response = mocked_client().engines().await?;
        assert_eq!(response, expected);
    });

    sync_test!(parse_engines_sync, {
        let (_m, expected) = mock_engines();
        let response = mocked_client().engines_sync()?;
        assert_eq!(response, expected);
    });

    /// Mocks a 404 error payload for `GET /engines/davinci` and returns the expected error.
    fn mock_engine() -> (Mock, api::ErrorMessage) {
        let mock = mockito::mock("GET", "/engines/davinci")
            .with_status(404)
            .with_header("content-type", "application/json")
            .with_body(
                r#"{
                    "error": {
                        "code": null,
                        "message": "Some kind of error happened",
                        "type": "some_error_type"
                    }
                }"#,
            )
            .create();
        let expected = api::ErrorMessage {
            message: "Some kind of error happened".into(),
            error_type: "some_error_type".into(),
        };
        (mock, expected)
    }

    async_test!(engine_error_response_async, {
        let (_m, expected) = mock_engine();
        let response = mocked_client().engine(api::Engine::Davinci).await;
        assert_api_error(response, &expected);
    });

    sync_test!(engine_error_response_sync, {
        let (_m, expected) = mock_engine();
        let response = mocked_client().engine_sync(api::Engine::Davinci);
        assert_api_error(response, &expected);
    });

    /// Mocks one successful completion (expected to be hit exactly once) and returns the
    /// mock, the arguments that target it, and the expected parsed completion.
    fn mock_completion() -> crate::Result<(Mock, CompletionArgs, Completion)> {
        let mock = mockito::mock("POST", "/engines/davinci/completions")
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(
                r#"{
                    "id": "cmpl-uqkvlQyYK7bGYrRHQ0eXlWi7",
                    "object": "text_completion",
                    "created": 1589478378,
                    "model": "davinci:2020-05-03",
                    "choices": [
                        {
                            "text": " there was a girl who",
                            "index": 0,
                            "logprobs": null,
                            "finish_reason": "length"
                        }
                    ]
                }"#,
            )
            .expect(1)
            .create();
        let args = api::CompletionArgs::builder()
            .engine(api::Engine::Davinci)
            .prompt("Once upon a time")
            .max_tokens(5)
            .temperature(1.0)
            .top_p(1.0)
            .n(1)
            .stop(vec!["\n".into()])
            .build()?;
        let expected = api::Completion {
            id: "cmpl-uqkvlQyYK7bGYrRHQ0eXlWi7".into(),
            created: 1589478378,
            model: "davinci:2020-05-03".into(),
            choices: vec![api::Choice {
                text: " there was a girl who".into(),
                index: 0,
                logprobs: None,
                finish_reason: api::FinishReason::MaxTokensReached,
            }],
        };
        Ok((mock, args, expected))
    }

    /// Field-by-field equality for completions (`Completion` does not derive `PartialEq`).
    ///
    /// Compares every choice, not just the first. `LogProbs` has no `PartialEq`, so only
    /// its presence is compared.
    fn assert_completion_equal(a: Completion, b: Completion) {
        assert_eq!(a.model, b.model);
        assert_eq!(a.id, b.id);
        assert_eq!(a.created, b.created);
        assert_eq!(a.choices.len(), b.choices.len());
        for (a_choice, b_choice) in a.choices.iter().zip(&b.choices) {
            assert_eq!(a_choice.text, b_choice.text);
            assert_eq!(a_choice.index, b_choice.index);
            assert_eq!(a_choice.logprobs.is_none(), b_choice.logprobs.is_none());
            assert_eq!(a_choice.finish_reason, b_choice.finish_reason);
        }
    }

    async_test!(completion_args_async, {
        let (m, args, expected) = mock_completion()?;
        let response = mocked_client().complete_prompt(args).await?;
        assert_completion_equal(response, expected);
        m.assert();
    });

    sync_test!(completion_args_sync, {
        let (m, args, expected) = mock_completion()?;
        let response = mocked_client().complete_prompt_sync(args)?;
        assert_completion_equal(response, expected);
        m.assert();
    });
}
/// Tests against the live API; they need network access and the `OPENAI_SK` env var.
#[cfg(test)]
mod integration {
    use crate::{
        api::{self, Completion},
        Client, Error,
    };

    /// Builds a client authenticated with the token in `OPENAI_SK`.
    ///
    /// # Panics
    /// Panics if `OPENAI_SK` is not set.
    fn get_client() -> Client {
        // Called by many tests; only the first logger initialisation can succeed.
        let _ = env_logger::builder().is_test(true).try_init();
        let sk = std::env::var("OPENAI_SK").expect(
            "To run integration tests, you must set the OPENAI_SK env var to your api token",
        );
        Client::new(&sk)
    }

    /// Asserts the four base engines are among `engines`. Shared by the async and sync
    /// tests so both check the same thing (the async test used to assert nothing).
    fn assert_base_engines_listed(engines: Vec<api::EngineInfo>) {
        let ids = engines.into_iter().map(|ei| ei.id).collect::<Vec<_>>();
        for engine in &[
            api::Engine::Ada,
            api::Engine::Babbage,
            api::Engine::Curie,
            api::Engine::Davinci,
        ] {
            assert!(ids.contains(engine), "{} missing from {:?}", engine, ids);
        }
    }

    async_test!(can_get_engines_async, {
        let client = get_client();
        assert_base_engines_listed(client.engines().await?);
    });

    sync_test!(can_get_engines_sync, {
        let client = get_client();
        assert_base_engines_listed(client.engines_sync()?);
    });

    /// Asserts that `result` is the API's "no engine with that ID" error for `ada`.
    fn assert_expected_engine_failure<T>(result: Result<T, Error>)
    where
        T: std::fmt::Debug,
    {
        match result {
            Err(Error::APIError(api::ErrorMessage {
                message,
                error_type,
            })) => {
                assert_eq!(message, "No engine with that ID: ada");
                assert_eq!(error_type, "invalid_request_error");
            }
            _ => {
                panic!("Expected an error message, got {:?}", result)
            }
        }
    }

    async_test!(can_get_engine_async, {
        let client = get_client();
        assert_expected_engine_failure(client.engine(api::Engine::Ada).await);
    });

    sync_test!(can_get_engine_sync, {
        let client = get_client();
        assert_expected_engine_failure(client.engine_sync(api::Engine::Ada));
    });

    async_test!(complete_string_async, {
        let client = get_client();
        client.complete_prompt("Hey there").await?;
    });

    sync_test!(complete_string_sync, {
        let client = get_client();
        client.complete_prompt_sync("Hey there")?;
    });

    /// Completion arguments with every builder setter given an explicit value.
    fn create_args() -> api::CompletionArgs {
        api::CompletionArgsBuilder::default()
            .prompt("Once upon a time,")
            .max_tokens(10)
            .temperature(0.5)
            .top_p(0.5)
            .n(1)
            .logprobs(3)
            .echo(false)
            .stop(vec!["\n".into()])
            .presence_penalty(0.5)
            .frequency_penalty(0.5)
            .logit_bias(maplit::hashmap! {
                "1".into() => 1.0,
                "23".into() => 0.0,
            })
            .build()
            .expect("Bug: build should succeed")
    }

    async_test!(complete_explicit_params_async, {
        let client = get_client();
        let args = create_args();
        client.complete_prompt(args).await?;
    });

    sync_test!(complete_explicit_params_sync, {
        let client = get_client();
        let args = create_args();
        client.complete_prompt_sync(args)?;
    });

    /// Arguments meant to stop generation on a stop sequence (`#` or newline) well before
    /// `max_tokens` is reached.
    fn stop_condition_args() -> api::CompletionArgs {
        let mut args = api::CompletionArgs::builder();
        args.prompt(
            r#"
Q: Please type `#` now
A:"#,
        )
        .temperature(0.0)
        .top_p(0.0)
        .max_tokens(100)
        .stop(vec!["#".into(), "\n".into()])
        .build()
        .expect("Bug: build should succeed")
    }

    /// Asserts the first choice stopped because of a stop sequence.
    fn assert_completion_finish_reason(completion: Completion) {
        assert_eq!(
            completion.choices[0].finish_reason,
            api::FinishReason::StopSequenceReached
        );
    }

    async_test!(complete_stop_condition_async, {
        let client = get_client();
        let args = stop_condition_args();
        assert_completion_finish_reason(client.complete_prompt(args).await?);
    });

    sync_test!(complete_stop_condition_sync, {
        let client = get_client();
        let args = stop_condition_args();
        assert_completion_finish_reason(client.complete_prompt_sync(args)?);
    });
}