1#![cfg(feature = "async")]
9
10use crate::client::Client;
11use crate::dry_run::DryRun;
12use crate::error::Result;
13use crate::messages::request::{CountTokensRequest, CreateMessageRequest};
14use crate::messages::response::{CountTokensResponse, Message};
15
16#[cfg(feature = "streaming")]
17use crate::messages::stream::EventStream;
18
19pub struct Messages<'a> {
24 client: &'a Client,
25}
26
27impl<'a> Messages<'a> {
28 pub(crate) fn new(client: &'a Client) -> Self {
29 Self { client }
30 }
31
32 pub async fn create(&self, request: CreateMessageRequest) -> Result<Message> {
37 self.create_with_beta(request, &[]).await
38 }
39
40 pub async fn create_with_beta(
43 &self,
44 request: CreateMessageRequest,
45 betas: &[&str],
46 ) -> Result<Message> {
47 let request_ref = &request;
48 self.client
49 .execute_with_retry(
50 || {
51 self.client
52 .request_builder(reqwest::Method::POST, "/v1/messages")
53 .json(request_ref)
54 },
55 betas,
56 )
57 .await
58 }
59
60 pub async fn count_tokens(&self, request: CountTokensRequest) -> Result<CountTokensResponse> {
63 self.count_tokens_with_beta(request, &[]).await
64 }
65
66 pub async fn count_tokens_with_beta(
68 &self,
69 request: CountTokensRequest,
70 betas: &[&str],
71 ) -> Result<CountTokensResponse> {
72 let request_ref = &request;
73 self.client
74 .execute_with_retry(
75 || {
76 self.client
77 .request_builder(reqwest::Method::POST, "/v1/messages/count_tokens")
78 .json(request_ref)
79 },
80 betas,
81 )
82 .await
83 }
84
85 pub fn dry_run(&self, request: &CreateMessageRequest) -> Result<DryRun> {
90 self.dry_run_with_beta(request, &[])
91 }
92
93 pub fn dry_run_with_beta(
96 &self,
97 request: &CreateMessageRequest,
98 betas: &[&str],
99 ) -> Result<DryRun> {
100 let builder = self
101 .client
102 .request_builder(reqwest::Method::POST, "/v1/messages")
103 .json(request);
104 self.client.render_dry_run(builder, betas)
105 }
106
107 #[cfg(feature = "pricing")]
117 #[cfg_attr(docsrs, doc(cfg(feature = "pricing")))]
118 pub async fn cost_preview(
119 &self,
120 request: &CreateMessageRequest,
121 pricing: &crate::pricing::PricingTable,
122 ) -> Result<crate::cost_preview::CostPreview> {
123 use crate::types::Usage;
124 let count = self.count_tokens(CountTokensRequest::from(request)).await?;
125 let input_tokens = count.input_tokens;
126 let max_output_tokens = request.max_tokens;
127 let input_cost_usd = pricing.cost(
128 &request.model,
129 &Usage {
130 input_tokens,
131 output_tokens: 0,
132 ..Usage::default()
133 },
134 );
135 let max_total_usd = pricing.cost(
136 &request.model,
137 &Usage {
138 input_tokens,
139 output_tokens: max_output_tokens,
140 ..Usage::default()
141 },
142 );
143 let max_output_cost_usd = max_total_usd - input_cost_usd;
144 Ok(crate::cost_preview::CostPreview {
145 model: request.model.clone(),
146 input_tokens,
147 max_output_tokens,
148 input_cost_usd,
149 max_output_cost_usd,
150 max_total_usd,
151 })
152 }
153
154 #[cfg(feature = "pricing")]
160 #[cfg_attr(docsrs, doc(cfg(feature = "pricing")))]
161 pub async fn cost_preview_cached(
162 &self,
163 request: &CreateMessageRequest,
164 pricing: &crate::pricing::PricingTable,
165 cache: &crate::cost_preview::CountTokensCache,
166 ) -> Result<crate::cost_preview::CostPreview> {
167 use crate::types::Usage;
168 let count_req = CountTokensRequest::from(request);
169 let key = crate::cost_preview::hash_request(&count_req);
170
171 let input_tokens = if let Some(cached) = cache.get(key) {
172 cached
173 } else {
174 let count = self.count_tokens(count_req).await?;
175 cache.put(key, count.input_tokens);
176 count.input_tokens
177 };
178
179 let max_output_tokens = request.max_tokens;
180 let input_cost_usd = pricing.cost(
181 &request.model,
182 &Usage {
183 input_tokens,
184 output_tokens: 0,
185 ..Usage::default()
186 },
187 );
188 let max_total_usd = pricing.cost(
189 &request.model,
190 &Usage {
191 input_tokens,
192 output_tokens: max_output_tokens,
193 ..Usage::default()
194 },
195 );
196 let max_output_cost_usd = max_total_usd - input_cost_usd;
197 Ok(crate::cost_preview::CostPreview {
198 model: request.model.clone(),
199 input_tokens,
200 max_output_tokens,
201 input_cost_usd,
202 max_output_cost_usd,
203 max_total_usd,
204 })
205 }
206
207 pub fn dry_run_count_tokens(&self, request: &CountTokensRequest) -> Result<DryRun> {
209 self.dry_run_count_tokens_with_beta(request, &[])
210 }
211
212 pub fn dry_run_count_tokens_with_beta(
214 &self,
215 request: &CountTokensRequest,
216 betas: &[&str],
217 ) -> Result<DryRun> {
218 let builder = self
219 .client
220 .request_builder(reqwest::Method::POST, "/v1/messages/count_tokens")
221 .json(request);
222 self.client.render_dry_run(builder, betas)
223 }
224
225 #[cfg(feature = "streaming")]
234 #[cfg_attr(docsrs, doc(cfg(feature = "streaming")))]
235 pub async fn create_stream(&self, request: CreateMessageRequest) -> Result<EventStream> {
236 self.create_stream_with_beta(request, &[]).await
237 }
238
239 #[cfg(feature = "streaming")]
241 #[cfg_attr(docsrs, doc(cfg(feature = "streaming")))]
242 pub async fn create_stream_with_beta(
243 &self,
244 mut request: CreateMessageRequest,
245 betas: &[&str],
246 ) -> Result<EventStream> {
247 request.stream = true;
248 let response = self
249 .client
250 .execute_streaming(
251 self.client
252 .request_builder(reqwest::Method::POST, "/v1/messages")
253 .json(&request),
254 betas,
255 )
256 .await?;
257 Ok(EventStream::from_response(response))
258 }
259}
260
#[cfg(test)]
mod tests {
    // HTTP-level tests start a local `wiremock` server and point a `Client` at it. They then
    // check the request the SDK sends and/or how it decodes the reply. Dry-run tests check
    // that nothing is sent.
    use super::*;
    use crate::messages::input::MessageInput;
    use crate::messages::response::Message;
    use crate::types::{ModelId, Role, StopReason};
    use pretty_assertions::assert_eq;
    use serde_json::json;
    use wiremock::matchers::{body_partial_json, header, header_exists, method, path};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    // Builds a client with the fixed key `sk-ant-test` whose base URL is the mock server.
    fn client_for(mock: &MockServer) -> Client {
        Client::builder()
            .api_key("sk-ant-test")
            .base_url(mock.uri())
            .build()
            .unwrap()
    }

    // Minimal successful `/v1/messages` reply: one text block, `end_turn`, and
    // 5 input / 2 output tokens of usage.
    fn fake_response_body() -> serde_json::Value {
        json!({
            "id": "msg_test",
            "type": "message",
            "role": "assistant",
            "content": [{"type": "text", "text": "Hi!"}],
            "model": "claude-sonnet-4-6",
            "stop_reason": "end_turn",
            "usage": {"input_tokens": 5, "output_tokens": 2}
        })
    }

    // `create` POSTs to /v1/messages with the API key and version headers and a JSON body
    // built from the typed request, then decodes the reply into a `Message`.
    #[tokio::test]
    async fn create_posts_to_v1_messages_with_typed_request_body() {
        let mock = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/v1/messages"))
            .and(header("x-api-key", "sk-ant-test"))
            .and(header("anthropic-version", crate::ANTHROPIC_VERSION))
            .and(body_partial_json(json!({
                "model": "claude-sonnet-4-6",
                "max_tokens": 64,
                "messages": [{"role": "user", "content": "hi"}]
            })))
            .respond_with(ResponseTemplate::new(200).set_body_json(fake_response_body()))
            .mount(&mock)
            .await;

        let client = client_for(&mock);
        let req = CreateMessageRequest::builder()
            .model(ModelId::SONNET_4_6)
            .max_tokens(64)
            .user("hi")
            .build()
            .unwrap();
        let resp = client.messages().create(req).await.unwrap();

        assert_eq!(resp.id, "msg_test");
        assert_eq!(resp.role, Role::Assistant);
        assert_eq!(resp.stop_reason, Some(StopReason::EndTurn));
        assert_eq!(resp.usage.input_tokens, 5);
    }

    // A single per-request beta flag arrives verbatim as the `anthropic-beta` header.
    #[tokio::test]
    async fn create_with_beta_attaches_per_request_beta_header() {
        let mock = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/v1/messages"))
            .and(header_exists("anthropic-beta"))
            .respond_with(ResponseTemplate::new(200).set_body_json(fake_response_body()))
            .mount(&mock)
            .await;

        let client = client_for(&mock);
        let req = CreateMessageRequest::builder()
            .model(ModelId::SONNET_4_6)
            .max_tokens(8)
            .user("x")
            .build()
            .unwrap();

        let _: Message = client
            .messages()
            .create_with_beta(req, &["computer-use-2025-01-24"])
            .await
            .unwrap();

        // Inspect the recorded request to pin the exact header value, not just its presence.
        let received = &mock.received_requests().await.unwrap()[0];
        let beta = received
            .headers
            .get("anthropic-beta")
            .unwrap()
            .to_str()
            .unwrap();
        assert_eq!(beta, "computer-use-2025-01-24");
    }

    // `dry_run` renders method, URL, headers and body but performs no HTTP request.
    #[tokio::test]
    async fn dry_run_renders_request_without_sending() {
        let mock = MockServer::start().await;
        // No mocks are mounted: any request actually sent would be recorded below.
        let client = client_for(&mock);
        let req = CreateMessageRequest::builder()
            .model(ModelId::SONNET_4_6)
            .max_tokens(64)
            .user("hello")
            .build()
            .unwrap();

        let dr = client.messages().dry_run(&req).unwrap();

        assert_eq!(dr.method, reqwest::Method::POST);
        assert_eq!(dr.url, format!("{}/v1/messages", mock.uri()));
        assert_eq!(dr.headers.get("x-api-key").unwrap(), "sk-ant-test");
        assert_eq!(
            dr.headers.get("anthropic-version").unwrap(),
            crate::ANTHROPIC_VERSION
        );
        assert_eq!(dr.body["model"], "claude-sonnet-4-6");
        assert_eq!(dr.body["max_tokens"], 64);
        assert_eq!(dr.body["messages"][0]["role"], "user");

        // The mock server must have seen zero traffic.
        assert_eq!(mock.received_requests().await.unwrap().len(), 0);
    }

    // Beta flags passed to `dry_run_with_beta` show up in the rendered headers.
    #[tokio::test]
    async fn dry_run_with_beta_includes_anthropic_beta_header() {
        let mock = MockServer::start().await;
        let client = client_for(&mock);
        let req = CreateMessageRequest::builder()
            .model(ModelId::SONNET_4_6)
            .max_tokens(8)
            .user("x")
            .build()
            .unwrap();

        let dr = client
            .messages()
            .dry_run_with_beta(&req, &["computer-use-2025-01-24"])
            .unwrap();

        assert_eq!(
            dr.headers.get("anthropic-beta").unwrap(),
            "computer-use-2025-01-24"
        );
    }

    // The count-tokens dry run targets the count_tokens endpoint, not /v1/messages.
    #[tokio::test]
    async fn dry_run_count_tokens_uses_count_tokens_path() {
        let mock = MockServer::start().await;
        let client = client_for(&mock);
        let req = CountTokensRequest::builder()
            .model(ModelId::HAIKU_4_5)
            .user("x")
            .build()
            .unwrap();

        let dr = client.messages().dry_run_count_tokens(&req).unwrap();
        assert!(dr.url.ends_with("/v1/messages/count_tokens"));
        assert_eq!(dr.body["model"], "claude-haiku-4-5-20251001");
    }

    // A 400 error body surfaces as an `Err` that exposes both the HTTP status and the
    // server's `request-id` header.
    #[tokio::test]
    async fn create_propagates_api_error_with_request_id() {
        let mock = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/v1/messages"))
            .respond_with(
                ResponseTemplate::new(400)
                    .insert_header("request-id", "req_xyz")
                    .set_body_json(json!({
                        "type": "error",
                        "error": {"type": "invalid_request_error", "message": "bad input"}
                    })),
            )
            .mount(&mock)
            .await;

        let client = client_for(&mock);
        let req = CreateMessageRequest::builder()
            .model(ModelId::SONNET_4_6)
            .max_tokens(8)
            .user("x")
            .build()
            .unwrap();

        let err = client.messages().create(req).await.unwrap_err();
        assert_eq!(err.request_id(), Some("req_xyz"));
        assert_eq!(err.status(), Some(http::StatusCode::BAD_REQUEST));
    }

    // `count_tokens` POSTs the model and messages to the count_tokens endpoint and decodes
    // `input_tokens` from the reply.
    #[tokio::test]
    async fn count_tokens_posts_to_count_tokens_endpoint() {
        let mock = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/v1/messages/count_tokens"))
            .and(body_partial_json(json!({
                "model": "claude-haiku-4-5-20251001",
                "messages": [{"role": "user", "content": "x"}]
            })))
            .respond_with(ResponseTemplate::new(200).set_body_json(json!({"input_tokens": 7})))
            .mount(&mock)
            .await;

        let client = client_for(&mock);
        let req = CountTokensRequest::builder()
            .model(ModelId::HAIKU_4_5)
            .user("x")
            .build()
            .unwrap();
        let resp = client.messages().count_tokens(req).await.unwrap();
        assert_eq!(resp.input_tokens, 7);
    }

    // `cost_preview` counts tokens over the network, then prices the prompt-only and the
    // full-`max_tokens` scenarios.
    #[cfg(feature = "pricing")]
    #[tokio::test]
    async fn cost_preview_calls_count_tokens_and_computes_bounds() {
        let mock = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/v1/messages/count_tokens"))
            .and(body_partial_json(json!({
                "model": "claude-sonnet-4-6",
                "messages": [{"role": "user", "content": "hi"}]
            })))
            .respond_with(ResponseTemplate::new(200).set_body_json(json!({"input_tokens": 1000})))
            .mount(&mock)
            .await;

        let client = client_for(&mock);
        let req = CreateMessageRequest::builder()
            .model(ModelId::SONNET_4_6)
            .max_tokens(2000)
            .user("hi")
            .build()
            .unwrap();

        let pricing = crate::pricing::PricingTable::default();
        let preview = client
            .messages()
            .cost_preview(&req, &pricing)
            .await
            .unwrap();

        // The expected figures imply the default table prices Sonnet 4.6 at $3/MTok input and
        // $15/MTok output: 1000 * 3e-6 = 0.003 and 2000 * 15e-6 = 0.030.
        assert_eq!(preview.input_tokens, 1000);
        assert_eq!(preview.max_output_tokens, 2000);
        assert!((preview.input_cost_usd - 0.003).abs() < 1e-9);
        assert!((preview.max_output_cost_usd - 0.030).abs() < 1e-9);
        assert!((preview.max_total_usd - 0.033).abs() < 1e-9);
    }

    // `CostPreview::cost_for` gives a point estimate for a specific output length:
    // 0.003 input + 500 * 15e-6 output = 0.0105.
    #[cfg(feature = "pricing")]
    #[tokio::test]
    async fn cost_preview_cost_for_returns_point_estimate() {
        let mock = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/v1/messages/count_tokens"))
            .respond_with(ResponseTemplate::new(200).set_body_json(json!({"input_tokens": 1000})))
            .mount(&mock)
            .await;

        let client = client_for(&mock);
        let req = CreateMessageRequest::builder()
            .model(ModelId::SONNET_4_6)
            .max_tokens(5000)
            .user("hi")
            .build()
            .unwrap();

        let pricing = crate::pricing::PricingTable::default();
        let preview = client
            .messages()
            .cost_preview(&req, &pricing)
            .await
            .unwrap();
        let estimate = preview.cost_for(500, &pricing);
        assert!((estimate - 0.0105).abs() < 1e-9);
    }

    // A second identical `cost_preview_cached` call must be served from the cache.
    #[cfg(feature = "pricing")]
    #[tokio::test]
    async fn cost_preview_cached_skips_network_on_hit() {
        let mock = MockServer::start().await;
        Mock::given(method("POST"))
            // `.expect(1)` is verified by wiremock when `mock` is dropped, so a second
            // count_tokens call (a cache miss) fails the test.
            .and(path("/v1/messages/count_tokens"))
            .respond_with(ResponseTemplate::new(200).set_body_json(json!({"input_tokens": 42})))
            .expect(1)
            .mount(&mock)
            .await;

        let client = client_for(&mock);
        let req = CreateMessageRequest::builder()
            .model(ModelId::SONNET_4_6)
            .max_tokens(64)
            .user("hi")
            .build()
            .unwrap();
        let pricing = crate::pricing::PricingTable::default();
        let cache = crate::cost_preview::CountTokensCache::new(8);

        let p1 = client
            .messages()
            .cost_preview_cached(&req, &pricing, &cache)
            .await
            .unwrap();
        let p2 = client
            .messages()
            .cost_preview_cached(&req, &pricing, &cache)
            .await
            .unwrap();

        assert_eq!(p1, p2);
        assert_eq!(cache.len(), 1);
    }

    // Requests with different message content must hash to different cache keys.
    #[cfg(feature = "pricing")]
    #[tokio::test]
    async fn cost_preview_cached_distinguishes_different_requests() {
        let mock = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/v1/messages/count_tokens"))
            .and(body_partial_json(
                json!({"messages": [{"role": "user", "content": "alpha"}]}),
            ))
            .respond_with(ResponseTemplate::new(200).set_body_json(json!({"input_tokens": 100})))
            .mount(&mock)
            .await;
        Mock::given(method("POST"))
            .and(path("/v1/messages/count_tokens"))
            .and(body_partial_json(
                json!({"messages": [{"role": "user", "content": "beta"}]}),
            ))
            .respond_with(ResponseTemplate::new(200).set_body_json(json!({"input_tokens": 200})))
            .mount(&mock)
            .await;

        let client = client_for(&mock);
        let pricing = crate::pricing::PricingTable::default();
        let cache = crate::cost_preview::CountTokensCache::new(8);

        let req_a = CreateMessageRequest::builder()
            .model(ModelId::SONNET_4_6)
            .max_tokens(64)
            .user("alpha")
            .build()
            .unwrap();
        let req_b = CreateMessageRequest::builder()
            .model(ModelId::SONNET_4_6)
            .max_tokens(64)
            .user("beta")
            .build()
            .unwrap();

        let pa = client
            .messages()
            .cost_preview_cached(&req_a, &pricing, &cache)
            .await
            .unwrap();
        let pb = client
            .messages()
            .cost_preview_cached(&req_b, &pricing, &cache)
            .await
            .unwrap();

        assert_eq!(pa.input_tokens, 100);
        assert_eq!(pb.input_tokens, 200);
        assert_eq!(cache.len(), 2);
    }

    // Converting a create request to a count-tokens request keeps model and messages but
    // omits `max_tokens` and sampling parameters from the serialized body.
    #[test]
    fn count_tokens_request_from_create_drops_max_tokens_and_sampling() {
        let req = CreateMessageRequest::builder()
            .model(ModelId::SONNET_4_6)
            .max_tokens(64)
            .temperature(0.7)
            .user("hello")
            .build()
            .unwrap();

        let count_req = CountTokensRequest::from(&req);
        assert_eq!(count_req.model, req.model);
        assert_eq!(count_req.messages.len(), 1);
        let body = serde_json::to_value(&count_req).unwrap();
        assert!(body.get("max_tokens").is_none());
        assert!(body.get("temperature").is_none());
    }

    // A trailing assistant turn (prefill) is sent as the last message, in order.
    #[tokio::test]
    async fn create_appends_assistant_prefill_in_history() {
        let mock = MockServer::start().await;
        // The mock only matches when both turns are present in this order.
        Mock::given(method("POST"))
            .and(path("/v1/messages"))
            .and(body_partial_json(json!({
                "messages": [
                    {"role": "user", "content": "hi"},
                    {"role": "assistant", "content": "Sure, "}
                ]
            })))
            .respond_with(ResponseTemplate::new(200).set_body_json(fake_response_body()))
            .mount(&mock)
            .await;

        let client = client_for(&mock);
        let req = CreateMessageRequest::builder()
            .model(ModelId::SONNET_4_6)
            .max_tokens(8)
            .user("hi")
            .assistant("Sure, ")
            .build()
            .unwrap();
        let _ = client.messages().create(req).await.unwrap();
    }

    // A 529 (overloaded) on the first attempt is retried and the second attempt succeeds.
    // The 529 mock is mounted first and capped at one use, so the retry falls through to
    // the 200 mock.
    #[tokio::test]
    async fn create_retries_on_overloaded_then_succeeds() {
        let mock = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/v1/messages"))
            .respond_with(ResponseTemplate::new(529))
            .up_to_n_times(1)
            .mount(&mock)
            .await;
        Mock::given(method("POST"))
            .and(path("/v1/messages"))
            .respond_with(ResponseTemplate::new(200).set_body_json(fake_response_body()))
            .mount(&mock)
            .await;

        let client = Client::builder()
            // Millisecond backoff without jitter keeps the test fast and deterministic.
            .api_key("sk-ant-x")
            .base_url(mock.uri())
            .retry(crate::retry::RetryPolicy {
                max_attempts: 3,
                initial_backoff: std::time::Duration::from_millis(1),
                max_backoff: std::time::Duration::from_millis(5),
                jitter: crate::retry::Jitter::None,
                respect_retry_after: false,
            })
            .build()
            .unwrap();

        let req = CreateMessageRequest::builder()
            .model(ModelId::SONNET_4_6)
            .max_tokens(8)
            .user("x")
            .build()
            .unwrap();
        let resp = client.messages().create(req).await.unwrap();
        assert_eq!(resp.id, "msg_test");
        assert_eq!(mock.received_requests().await.unwrap().len(), 2);
    }

    // Compile-time check: `messages()` borrows the client, so handles can be created,
    // dropped and re-created from the same client.
    #[test]
    fn messages_namespace_borrows_client() {
        let client = Client::new("sk-ant-x");
        {
            let _m = client.messages();
        }
        let _ = client.messages();

        // Also keeps the `MessageInput` import exercised.
        let _: MessageInput = MessageInput::user("x");
    }

    // Canned SSE transcript for a one-block text reply "Hello world". It contains 8 events:
    // start, block start, ping, two deltas, block stop, message delta, stop.
    #[cfg(feature = "streaming")]
    fn sse_corpus() -> &'static str {
        concat!(
            "event: message_start\n",
            "data: {\"type\":\"message_start\",\"message\":{\"id\":\"msg_S\",\"type\":\"message\",\"role\":\"assistant\",\"content\":[],\"model\":\"claude-sonnet-4-6\",\"usage\":{\"input_tokens\":3,\"output_tokens\":0}}}\n",
            "\n",
            "event: content_block_start\n",
            "data: {\"type\":\"content_block_start\",\"index\":0,\"content_block\":{\"type\":\"text\",\"text\":\"\"}}\n",
            "\n",
            "event: ping\n",
            "data: {\"type\":\"ping\"}\n",
            "\n",
            "event: content_block_delta\n",
            "data: {\"type\":\"content_block_delta\",\"index\":0,\"delta\":{\"type\":\"text_delta\",\"text\":\"Hello\"}}\n",
            "\n",
            "event: content_block_delta\n",
            "data: {\"type\":\"content_block_delta\",\"index\":0,\"delta\":{\"type\":\"text_delta\",\"text\":\" world\"}}\n",
            "\n",
            "event: content_block_stop\n",
            "data: {\"type\":\"content_block_stop\",\"index\":0}\n",
            "\n",
            "event: message_delta\n",
            "data: {\"type\":\"message_delta\",\"delta\":{\"stop_reason\":\"end_turn\"},\"usage\":{\"input_tokens\":3,\"output_tokens\":2}}\n",
            "\n",
            "event: message_stop\n",
            "data: {\"type\":\"message_stop\"}\n",
            "\n",
        )
    }

    // Aggregating the stream reassembles the text deltas into a single text block and
    // carries over the id, stop reason and final usage.
    #[cfg(feature = "streaming")]
    #[tokio::test]
    async fn create_stream_aggregates_to_full_message() {
        use crate::messages::content::{ContentBlock, KnownBlock};
        use crate::messages::stream::EventStream;

        let mock = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/v1/messages"))
            .and(body_partial_json(json!({"stream": true})))
            .respond_with(
                ResponseTemplate::new(200)
                    .insert_header("content-type", "text/event-stream")
                    .set_body_string(sse_corpus()),
            )
            .mount(&mock)
            .await;

        let client = client_for(&mock);
        let req = CreateMessageRequest::builder()
            .model(ModelId::SONNET_4_6)
            .max_tokens(8)
            .user("hi")
            .build()
            .unwrap();

        let stream: EventStream = client.messages().create_stream(req).await.unwrap();
        let msg = stream.aggregate().await.unwrap();

        assert_eq!(msg.id, "msg_S");
        assert_eq!(msg.stop_reason, Some(StopReason::EndTurn));
        assert_eq!(msg.usage.output_tokens, 2);
        assert_eq!(msg.content.len(), 1);
        match &msg.content[0] {
            ContentBlock::Known(KnownBlock::Text { text, .. }) => {
                assert_eq!(text, "Hello world");
            }
            _ => panic!("expected text block"),
        }
    }

    // The stream can also be consumed event-by-event as a `futures` `Stream`. The lower
    // bound of 7 (the corpus has 8) presumably tolerates `ping` being filtered — confirm.
    #[cfg(feature = "streaming")]
    #[tokio::test]
    async fn create_stream_yields_individual_events_for_iterator_use() {
        use futures_util::StreamExt;

        let mock = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/v1/messages"))
            .respond_with(
                ResponseTemplate::new(200)
                    .insert_header("content-type", "text/event-stream")
                    .set_body_string(sse_corpus()),
            )
            .mount(&mock)
            .await;

        let client = client_for(&mock);
        let req = CreateMessageRequest::builder()
            .model(ModelId::SONNET_4_6)
            .max_tokens(8)
            .user("hi")
            .build()
            .unwrap();

        let mut stream = client.messages().create_stream(req).await.unwrap();
        let mut count = 0;
        let mut saw_message_stop = false;
        while let Some(ev) = stream.next().await {
            let ev = ev.unwrap();
            count += 1;
            if ev.type_tag() == Some("message_stop") {
                saw_message_stop = true;
            }
        }
        assert!(saw_message_stop, "expected to see message_stop event");
        assert!(count >= 7, "expected at least 7 events, got {count}");
    }

    // An error status on the initial streaming response is returned from `create_stream`
    // itself, with status and request id.
    // NOTE(review): despite the name, this exercises an HTTP 401 response, not a
    // connection failure.
    #[cfg(feature = "streaming")]
    #[tokio::test]
    async fn create_stream_propagates_connect_error() {
        let mock = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/v1/messages"))
            .respond_with(
                ResponseTemplate::new(401)
                    .insert_header("request-id", "req_unauth")
                    .set_body_json(json!({
                        "type": "error",
                        "error": {"type": "authentication_error", "message": "bad key"}
                    })),
            )
            .mount(&mock)
            .await;

        let client = client_for(&mock);
        let req = CreateMessageRequest::builder()
            .model(ModelId::SONNET_4_6)
            .max_tokens(8)
            .user("hi")
            .build()
            .unwrap();

        let err = client.messages().create_stream(req).await.unwrap_err();
        assert_eq!(err.status(), Some(http::StatusCode::UNAUTHORIZED));
        assert_eq!(err.request_id(), Some("req_unauth"));
    }

    // `create_stream` forces `"stream": true` into the body even though the builder never
    // set it. The mock only matches such bodies, so a missing flag makes `unwrap` fail.
    #[cfg(feature = "streaming")]
    #[tokio::test]
    async fn create_stream_sets_stream_true_in_request_body() {
        let mock = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/v1/messages"))
            .and(body_partial_json(json!({"stream": true})))
            .respond_with(
                ResponseTemplate::new(200)
                    .insert_header("content-type", "text/event-stream")
                    .set_body_string(sse_corpus()),
            )
            .mount(&mock)
            .await;

        let client = client_for(&mock);
        let req = CreateMessageRequest::builder()
            .model(ModelId::SONNET_4_6)
            .max_tokens(8)
            .user("x")
            .build()
            .unwrap();
        let _ = client.messages().create_stream(req).await.unwrap();
    }
}